Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: image edits/inpaiting 支持 replicate 的 flux remix #1986

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
name: CI

# This setup assumes that you run the unit tests with code coverage in the same
# workflow that will also print the coverage report as comment to the pull request.
# workflow that will also print the coverage report as comment to the pull request.
# Therefore, you need to trigger this workflow when a pull request is (re)opened or
# when new code is pushed to the branch of the pull request. In addition, you also
# need to trigger this workflow when new code is pushed to the main branch because
# need to trigger this workflow when new code is pushed to the main branch because
# we need to upload the code coverage results as artifact for the main branch as
# well since it will be the baseline code coverage.
#
#
# We do not want to trigger the workflow for pushes to *any* branch because this
# would trigger our jobs twice on pull requests (once from "push" event and once
# from "pull_request->synchronize")
Expand All @@ -31,7 +31,7 @@ jobs:
with:
go-version: ^1.22

# When you execute your unit tests, make sure to use the "-coverprofile" flag to write a
# When you execute your unit tests, make sure to use the "-coverprofile" flag to write a
# coverage profile to a file. You will need the name of the file (e.g. "coverage.txt")
# in the next step as well as the next job.
- name: Test
Expand Down
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,5 @@ logs
data
/web/node_modules
cmd.md
.env
.env
/one-api
2 changes: 2 additions & 0 deletions common/ctxkey/key.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package ctxkey
const (
Config = "config"
Id = "id"
RequestId = "X-Oneapi-Request-Id"
Username = "username"
Role = "role"
Status = "status"
Expand All @@ -15,6 +16,7 @@ const (
Group = "group"
ModelMapping = "model_mapping"
ChannelName = "channel_name"
ContentType = "content_type"
TokenId = "token_id"
TokenName = "token_name"
BaseURL = "base_url"
Expand Down
34 changes: 22 additions & 12 deletions common/gin.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,43 +3,53 @@
import (
"bytes"
"encoding/json"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/ctxkey"
"io"
"reflect"
"strings"

"github.com/gin-gonic/gin"
"github.com/pkg/errors"
"github.com/songquanpeng/one-api/common/ctxkey"
)

func GetRequestBody(c *gin.Context) ([]byte, error) {
requestBody, _ := c.Get(ctxkey.KeyRequestBody)
if requestBody != nil {
return requestBody.([]byte), nil
func GetRequestBody(c *gin.Context) (requestBody []byte, err error) {
if requestBodyCache, _ := c.Get(ctxkey.KeyRequestBody); requestBodyCache != nil {
return requestBodyCache.([]byte), nil

Check warning on line 17 in common/gin.go

View check run for this annotation

Codecov / codecov/patch

common/gin.go#L15-L17

Added lines #L15 - L17 were not covered by tests
}
requestBody, err := io.ReadAll(c.Request.Body)
requestBody, err = io.ReadAll(c.Request.Body)

Check warning on line 19 in common/gin.go

View check run for this annotation

Codecov / codecov/patch

common/gin.go#L19

Added line #L19 was not covered by tests
if err != nil {
return nil, err
}
_ = c.Request.Body.Close()
c.Set(ctxkey.KeyRequestBody, requestBody)
return requestBody.([]byte), nil

return requestBody, nil

Check warning on line 26 in common/gin.go

View check run for this annotation

Codecov / codecov/patch

common/gin.go#L25-L26

Added lines #L25 - L26 were not covered by tests
}

func UnmarshalBodyReusable(c *gin.Context, v any) error {
requestBody, err := GetRequestBody(c)
if err != nil {
return err
}

// check v should be a pointer
if v == nil || reflect.TypeOf(v).Kind() != reflect.Ptr {
return errors.Errorf("UnmarshalBodyReusable only accept pointer, got %v", reflect.TypeOf(v))
}

Check warning on line 38 in common/gin.go

View check run for this annotation

Codecov / codecov/patch

common/gin.go#L36-L38

Added lines #L36 - L38 were not covered by tests

contentType := c.Request.Header.Get("Content-Type")
if strings.HasPrefix(contentType, "application/json") {
err = json.Unmarshal(requestBody, &v)
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
err = json.Unmarshal(requestBody, v)

Check warning on line 42 in common/gin.go

View check run for this annotation

Codecov / codecov/patch

common/gin.go#L42

Added line #L42 was not covered by tests
} else {
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
err = c.ShouldBind(&v)
err = c.ShouldBind(v)

Check warning on line 45 in common/gin.go

View check run for this annotation

Codecov / codecov/patch

common/gin.go#L45

Added line #L45 was not covered by tests
}
if err != nil {
return err
return errors.Wrap(err, "unmarshal request body failed")

Check warning on line 48 in common/gin.go

View check run for this annotation

Codecov / codecov/patch

common/gin.go#L48

Added line #L48 was not covered by tests
}
// Reset request body
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))

Check warning on line 52 in common/gin.go

View check run for this annotation

Codecov / codecov/patch

common/gin.go#L51-L52

Added lines #L51 - L52 were not covered by tests
return nil
}

Expand Down
18 changes: 9 additions & 9 deletions controller/relay.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@
func relayHelper(c *gin.Context, relayMode int) *model.ErrorWithStatusCode {
var err *model.ErrorWithStatusCode
switch relayMode {
case relaymode.ImagesGenerations:
case relaymode.ImagesGenerations,
relaymode.ImagesEdits:

Check warning on line 30 in controller/relay.go

View check run for this annotation

Codecov / codecov/patch

controller/relay.go#L30

Added line #L30 was not covered by tests
err = controller.RelayImageHelper(c, relayMode)
case relaymode.AudioSpeech:
fallthrough
Expand All @@ -45,10 +46,6 @@
func Relay(c *gin.Context) {
ctx := c.Request.Context()
relayMode := relaymode.GetByPath(c.Request.URL.Path)
if config.DebugEnabled {
requestBody, _ := common.GetRequestBody(c)
logger.Debugf(ctx, "request body: %s", string(requestBody))
}
channelId := c.GetInt(ctxkey.ChannelId)
userId := c.GetInt(ctxkey.Id)
bizErr := relayHelper(c, relayMode)
Expand All @@ -60,7 +57,7 @@
channelName := c.GetString(ctxkey.ChannelName)
group := c.GetString(ctxkey.Group)
originalModel := c.GetString(ctxkey.OriginalModel)
go processChannelRelayError(ctx, userId, channelId, channelName, bizErr)
go processChannelRelayError(ctx, userId, channelId, channelName, *bizErr)

Check warning on line 60 in controller/relay.go

View check run for this annotation

Codecov / codecov/patch

controller/relay.go#L60

Added line #L60 was not covered by tests
requestId := c.GetString(helper.RequestIdKey)
retryTimes := config.RetryTimes
if !shouldRetry(c, bizErr.StatusCode) {
Expand All @@ -87,9 +84,9 @@
channelId := c.GetInt(ctxkey.ChannelId)
lastFailedChannelId = channelId
channelName := c.GetString(ctxkey.ChannelName)
// BUG: bizErr is in race condition
go processChannelRelayError(ctx, userId, channelId, channelName, bizErr)
go processChannelRelayError(ctx, userId, channelId, channelName, *bizErr)

Check warning on line 87 in controller/relay.go

View check run for this annotation

Codecov / codecov/patch

controller/relay.go#L87

Added line #L87 was not covered by tests
}

if bizErr != nil {
if bizErr.StatusCode == http.StatusTooManyRequests {
bizErr.Error.Message = "当前分组上游负载已饱和,请稍后再试"
Expand Down Expand Up @@ -122,7 +119,10 @@
return true
}

func processChannelRelayError(ctx context.Context, userId int, channelId int, channelName string, err *model.ErrorWithStatusCode) {
func processChannelRelayError(ctx context.Context,
userId int, channelId int, channelName string,
// FIX: err should not use a pointer to avoid data race in concurrent situations
err model.ErrorWithStatusCode) {

Check warning on line 125 in controller/relay.go

View check run for this annotation

Codecov / codecov/patch

controller/relay.go#L125

Added line #L125 was not covered by tests
logger.Errorf(ctx, "relay error (channel id %d, user id: %d): %s", channelId, userId, err.Message)
// https://platform.openai.com/docs/guides/error-codes/api-errors
if monitor.ShouldDisableChannel(&err.Error, err.StatusCode) {
Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ require (
github.com/stretchr/testify v1.9.0
golang.org/x/crypto v0.24.0
golang.org/x/image v0.18.0
golang.org/x/sync v0.7.0
google.golang.org/api v0.187.0
gorm.io/driver/mysql v1.5.6
gorm.io/driver/postgres v1.5.7
Expand Down Expand Up @@ -99,7 +100,6 @@ require (
golang.org/x/arch v0.8.0 // indirect
golang.org/x/net v0.26.0 // indirect
golang.org/x/oauth2 v0.21.0 // indirect
golang.org/x/sync v0.7.0 // indirect
golang.org/x/sys v0.21.0 // indirect
golang.org/x/text v0.16.0 // indirect
golang.org/x/time v0.5.0 // indirect
Expand Down
1 change: 1 addition & 0 deletions middleware/distributor.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
if channel.SystemPrompt != nil && *channel.SystemPrompt != "" {
c.Set(ctxkey.SystemPrompt, *channel.SystemPrompt)
}
c.Set(ctxkey.ContentType, c.Request.Header.Get("Content-Type"))

Check warning on line 67 in middleware/distributor.go

View check run for this annotation

Codecov / codecov/patch

middleware/distributor.go#L67

Added line #L67 was not covered by tests
c.Set(ctxkey.ModelMapping, channel.GetModelMapping())
c.Set(ctxkey.OriginalModel, modelName) // for retry
c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
Expand Down
1 change: 1 addition & 0 deletions middleware/logger.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package middleware

import (
"fmt"

"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/helper"
)
Expand Down
1 change: 1 addition & 0 deletions middleware/request-id.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package middleware

import (
"context"

"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/helper"
)
Expand Down
23 changes: 13 additions & 10 deletions middleware/utils.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
package middleware

import (
"fmt"
"strings"

"github.com/gin-gonic/gin"
"github.com/pkg/errors"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger"
"strings"
)

func abortWithMessage(c *gin.Context, statusCode int, message string) {
Expand All @@ -24,28 +25,30 @@
var modelRequest ModelRequest
err := common.UnmarshalBodyReusable(c, &modelRequest)
if err != nil {
return "", fmt.Errorf("common.UnmarshalBodyReusable failed: %w", err)
return "", errors.Wrap(err, "common.UnmarshalBodyReusable failed")

Check warning on line 28 in middleware/utils.go

View check run for this annotation

Codecov / codecov/patch

middleware/utils.go#L28

Added line #L28 was not covered by tests
}
if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {

switch {
case strings.HasPrefix(c.Request.URL.Path, "/v1/moderations"):

Check warning on line 32 in middleware/utils.go

View check run for this annotation

Codecov / codecov/patch

middleware/utils.go#L31-L32

Added lines #L31 - L32 were not covered by tests
if modelRequest.Model == "" {
modelRequest.Model = "text-moderation-stable"
}
}
if strings.HasSuffix(c.Request.URL.Path, "embeddings") {
case strings.HasSuffix(c.Request.URL.Path, "embeddings"):

Check warning on line 36 in middleware/utils.go

View check run for this annotation

Codecov / codecov/patch

middleware/utils.go#L36

Added line #L36 was not covered by tests
if modelRequest.Model == "" {
modelRequest.Model = c.Param("model")
}
}
if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") {
case strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations"),
strings.HasPrefix(c.Request.URL.Path, "/v1/images/edits"):

Check warning on line 41 in middleware/utils.go

View check run for this annotation

Codecov / codecov/patch

middleware/utils.go#L41

Added line #L41 was not covered by tests
if modelRequest.Model == "" {
modelRequest.Model = "dall-e-2"
}
}
if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") || strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") {
case strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions"),
strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations"):

Check warning on line 46 in middleware/utils.go

View check run for this annotation

Codecov / codecov/patch

middleware/utils.go#L46

Added line #L46 was not covered by tests
if modelRequest.Model == "" {
modelRequest.Model = "whisper-1"
}
}

return modelRequest.Model, nil
}

Expand Down
2 changes: 1 addition & 1 deletion monitor/manage.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
strings.Contains(lowerMessage, "credit") ||
strings.Contains(lowerMessage, "balance") ||
strings.Contains(lowerMessage, "permission denied") ||
strings.Contains(lowerMessage, "organization has been restricted") || // groq
strings.Contains(lowerMessage, "organization has been restricted") || // groq

Check warning on line 37 in monitor/manage.go

View check run for this annotation

Codecov / codecov/patch

monitor/manage.go#L37

Added line #L37 was not covered by tests
strings.Contains(lowerMessage, "已欠费") {
return true
}
Expand Down
Binary file removed one-api
Binary file not shown.
3 changes: 3 additions & 0 deletions relay/adaptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
"github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/adaptor/palm"
"github.com/songquanpeng/one-api/relay/adaptor/proxy"
"github.com/songquanpeng/one-api/relay/adaptor/replicate"
"github.com/songquanpeng/one-api/relay/adaptor/tencent"
"github.com/songquanpeng/one-api/relay/adaptor/vertexai"
"github.com/songquanpeng/one-api/relay/adaptor/xunfei"
Expand Down Expand Up @@ -61,6 +62,8 @@ func GetAdaptor(apiType int) adaptor.Adaptor {
return &vertexai.Adaptor{}
case apitype.Proxy:
return &proxy.Adaptor{}
case apitype.Replicate:
return &replicate.Adaptor{}
}
return nil
}
9 changes: 7 additions & 2 deletions relay/adaptor/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@
import (
"errors"
"fmt"
"io"
"net/http"

"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/client"
"github.com/songquanpeng/one-api/common/ctxkey"
"github.com/songquanpeng/one-api/relay/meta"
"io"
"net/http"
)

func SetupCommonRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) {
Expand All @@ -27,6 +29,9 @@
if err != nil {
return nil, fmt.Errorf("new request failed: %w", err)
}

req.Header.Set("Content-Type", c.GetString(ctxkey.ContentType))

Check warning on line 34 in relay/adaptor/common.go

View check run for this annotation

Codecov / codecov/patch

relay/adaptor/common.go#L33-L34

Added lines #L33 - L34 were not covered by tests
err = a.SetupRequestHeader(c, req, meta)
if err != nil {
return nil, fmt.Errorf("setup request header failed: %w", err)
Expand Down
6 changes: 3 additions & 3 deletions relay/adaptor/ollama/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@
TopP: request.TopP,
FrequencyPenalty: request.FrequencyPenalty,
PresencePenalty: request.PresencePenalty,
NumPredict: request.MaxTokens,
NumCtx: request.NumCtx,
NumPredict: request.MaxTokens,
NumCtx: request.NumCtx,

Check warning on line 35 in relay/adaptor/ollama/main.go

View check run for this annotation

Codecov / codecov/patch

relay/adaptor/ollama/main.go#L34-L35

Added lines #L34 - L35 were not covered by tests
},
Stream: request.Stream,
}
Expand Down Expand Up @@ -122,7 +122,7 @@
for scanner.Scan() {
data := scanner.Text()
if strings.HasPrefix(data, "}") {
data = strings.TrimPrefix(data, "}") + "}"
data = strings.TrimPrefix(data, "}") + "}"

Check warning on line 125 in relay/adaptor/ollama/main.go

View check run for this annotation

Codecov / codecov/patch

relay/adaptor/ollama/main.go#L125

Added line #L125 was not covered by tests
}

var ollamaResponse ChatResponse
Expand Down
3 changes: 3 additions & 0 deletions relay/adaptor/openai/adaptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,10 +111,13 @@
switch meta.Mode {
case relaymode.ImagesGenerations:
err, _ = ImageHandler(c, resp)
case relaymode.ImagesEdits:
err, _ = ImagesEditsHandler(c, resp)

Check warning on line 115 in relay/adaptor/openai/adaptor.go

View check run for this annotation

Codecov / codecov/patch

relay/adaptor/openai/adaptor.go#L114-L115

Added lines #L114 - L115 were not covered by tests
default:
err, usage = Handler(c, resp, meta.PromptTokens, meta.ActualModelName)
}
}

return
}

Expand Down
22 changes: 20 additions & 2 deletions relay/adaptor/openai/image.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,30 @@
import (
"bytes"
"encoding/json"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/relay/model"
"io"
"net/http"

"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/relay/model"
)

// ImagesEditsHandler just copy response body to client
//
// https://platform.openai.com/docs/api-reference/images/createEdit
func ImagesEditsHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
c.Writer.WriteHeader(resp.StatusCode)
for k, v := range resp.Header {
c.Writer.Header().Set(k, v[0])
}

Check warning on line 20 in relay/adaptor/openai/image.go

View check run for this annotation

Codecov / codecov/patch

relay/adaptor/openai/image.go#L16-L20

Added lines #L16 - L20 were not covered by tests

if _, err := io.Copy(c.Writer, resp.Body); err != nil {
return ErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
}
defer resp.Body.Close()

return nil, nil

Check warning on line 27 in relay/adaptor/openai/image.go

View check run for this annotation

Codecov / codecov/patch

relay/adaptor/openai/image.go#L22-L27

Added lines #L22 - L27 were not covered by tests
}

func ImageHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
var imageResponse ImageResponse
responseBody, err := io.ReadAll(resp.Body)
Expand Down
Loading
Loading