Skip to content

Commit

Permalink
added params override
Browse files Browse the repository at this point in the history
  • Loading branch information
Motor committed Oct 1, 2024
1 parent 64ae25b commit e187fe0
Show file tree
Hide file tree
Showing 5 changed files with 120 additions and 7 deletions.
1 change: 1 addition & 0 deletions common/ctxkey/key.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ const (
OriginalModel = "original_model"
Group = "group"
ModelMapping = "model_mapping"
ParamsOverride = "params_override"
ChannelName = "channel_name"
TokenId = "token_id"
TokenName = "token_name"
Expand Down
1 change: 1 addition & 0 deletions middleware/distributor.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode
c.Set(ctxkey.ChannelId, channel.Id)
c.Set(ctxkey.ChannelName, channel.Name)
c.Set(ctxkey.ModelMapping, channel.GetModelMapping())
c.Set(ctxkey.ParamsOverride, channel.GetParamsOverride())

Check warning on line 65 in middleware/distributor.go

View check run for this annotation

Codecov / codecov/patch

middleware/distributor.go#L65

Added line #L65 was not covered by tests
c.Set(ctxkey.OriginalModel, modelName) // for retry
c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
c.Set(ctxkey.BaseURL, channel.GetBaseURL())
Expand Down
15 changes: 15 additions & 0 deletions model/channel.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ type Channel struct {
Group string `json:"group" gorm:"type:varchar(32);default:'default'"`
UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0"`
ModelMapping *string `json:"model_mapping" gorm:"type:varchar(1024);default:''"`
ParamsOverride *string `json:"default_params_override" gorm:"type:text;default:''"`
Priority *int64 `json:"priority" gorm:"bigint;default:0"`
Config string `json:"config"`
}
Expand Down Expand Up @@ -123,6 +124,20 @@ func (channel *Channel) GetModelMapping() map[string]string {
return modelMapping
}

func (channel *Channel) GetParamsOverride() map[string]map[string]interface{} {
if channel.ParamsOverride == nil || *channel.ParamsOverride == "" || *channel.ParamsOverride == "{}" {
return nil

Check warning on line 129 in model/channel.go

View check run for this annotation

Codecov / codecov/patch

model/channel.go#L127-L129

Added lines #L127 - L129 were not covered by tests
}
paramsOverride := make(map[string]map[string]interface{})
err := json.Unmarshal([]byte(*channel.ParamsOverride), &paramsOverride)
if err != nil {
logger.SysError(fmt.Sprintf("failed to unmarshal params override for channel %d, error: %s", channel.Id, err.Error()))
return nil

Check warning on line 135 in model/channel.go

View check run for this annotation

Codecov / codecov/patch

model/channel.go#L131-L135

Added lines #L131 - L135 were not covered by tests
}
return paramsOverride

Check warning on line 137 in model/channel.go

View check run for this annotation

Codecov / codecov/patch

model/channel.go#L137

Added line #L137 was not covered by tests
}


func (channel *Channel) Insert() error {
var err error
err = DB.Create(channel).Error
Expand Down
104 changes: 97 additions & 7 deletions relay/controller/text.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ import (
"fmt"
"io"
"net/http"
"io/ioutil"
"context"

"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/logger"
Expand All @@ -23,13 +25,34 @@ import (
func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode {
ctx := c.Request.Context()
meta := meta.GetByContext(c)
// get & validate textRequest
textRequest, err := getAndValidateTextRequest(c, meta.Mode)
if err != nil {
logger.Errorf(ctx, "getAndValidateTextRequest failed: %s", err.Error())
return openai.ErrorWrapper(err, "invalid_text_request", http.StatusBadRequest)
}
meta.IsStream = textRequest.Stream

// Read the original request body
bodyBytes, err := ioutil.ReadAll(c.Request.Body)
if err != nil {
logger.Errorf(ctx, "Failed to read request body: %s", err.Error())
return openai.ErrorWrapper(err, "invalid_request_body", http.StatusBadRequest)

Check warning on line 33 in relay/controller/text.go

View check run for this annotation

Codecov / codecov/patch

relay/controller/text.go#L30-L33

Added lines #L30 - L33 were not covered by tests
}

// Restore the request body for `getAndValidateTextRequest`
c.Request.Body = ioutil.NopCloser(bytes.NewBuffer(bodyBytes))

Check warning on line 37 in relay/controller/text.go

View check run for this annotation

Codecov / codecov/patch

relay/controller/text.go#L37

Added line #L37 was not covered by tests

// Call `getAndValidateTextRequest`
textRequest, err := getAndValidateTextRequest(c, meta.Mode)
if err != nil {
logger.Errorf(ctx, "getAndValidateTextRequest failed: %s", err.Error())
return openai.ErrorWrapper(err, "invalid_text_request", http.StatusBadRequest)

Check warning on line 43 in relay/controller/text.go

View check run for this annotation

Codecov / codecov/patch

relay/controller/text.go#L40-L43

Added lines #L40 - L43 were not covered by tests
}
meta.IsStream = textRequest.Stream

Check warning on line 45 in relay/controller/text.go

View check run for this annotation

Codecov / codecov/patch

relay/controller/text.go#L45

Added line #L45 was not covered by tests

// Parse the request body into a map
var rawRequest map[string]interface{}
if err := json.Unmarshal(bodyBytes, &rawRequest); err != nil {
logger.Errorf(ctx, "Failed to parse request body into map: %s", err.Error())
return openai.ErrorWrapper(err, "invalid_json", http.StatusBadRequest)

Check warning on line 51 in relay/controller/text.go

View check run for this annotation

Codecov / codecov/patch

relay/controller/text.go#L48-L51

Added lines #L48 - L51 were not covered by tests
}

// Apply parameter overrides
applyParameterOverrides(ctx, meta, textRequest, rawRequest)

Check warning on line 55 in relay/controller/text.go

View check run for this annotation

Codecov / codecov/patch

relay/controller/text.go#L55

Added line #L55 was not covered by tests

// map model name
meta.OriginModelName = textRequest.Model
Expand Down Expand Up @@ -105,3 +128,70 @@ func getRequestBody(c *gin.Context, meta *meta.Meta, textRequest *model.GeneralO
requestBody = bytes.NewBuffer(jsonData)
return requestBody, nil
}

func applyParameterOverrides(ctx context.Context, meta *meta.Meta, textRequest *model.GeneralOpenAIRequest, rawRequest map[string]interface{}) {
if meta.ParamsOverride != nil {
modelName := meta.OriginModelName
if overrideParams, exists := meta.ParamsOverride[modelName]; exists {
logger.Infof(ctx, "Applying parameter overrides for model %s on channel %d", modelName, meta.ChannelId)
for key, value := range overrideParams {
if _, userSpecified := rawRequest[key]; !userSpecified {

Check warning on line 138 in relay/controller/text.go

View check run for this annotation

Codecov / codecov/patch

relay/controller/text.go#L132-L138

Added lines #L132 - L138 were not covered by tests
// Apply the override since the user didn't specify this parameter
switch key {
case "temperature":
if v, ok := value.(float64); ok {
textRequest.Temperature = v
} else if v, ok := value.(int); ok {
textRequest.Temperature = float64(v)

Check warning on line 145 in relay/controller/text.go

View check run for this annotation

Codecov / codecov/patch

relay/controller/text.go#L140-L145

Added lines #L140 - L145 were not covered by tests
}
case "max_tokens":
if v, ok := value.(float64); ok {
textRequest.MaxTokens = int(v)
} else if v, ok := value.(int); ok {
textRequest.MaxTokens = v

Check warning on line 151 in relay/controller/text.go

View check run for this annotation

Codecov / codecov/patch

relay/controller/text.go#L147-L151

Added lines #L147 - L151 were not covered by tests
}
case "top_p":
if v, ok := value.(float64); ok {
textRequest.TopP = v
} else if v, ok := value.(int); ok {
textRequest.TopP = float64(v)

Check warning on line 157 in relay/controller/text.go

View check run for this annotation

Codecov / codecov/patch

relay/controller/text.go#L153-L157

Added lines #L153 - L157 were not covered by tests
}
case "frequency_penalty":
if v, ok := value.(float64); ok {
textRequest.FrequencyPenalty = v
} else if v, ok := value.(int); ok {
textRequest.FrequencyPenalty = float64(v)

Check warning on line 163 in relay/controller/text.go

View check run for this annotation

Codecov / codecov/patch

relay/controller/text.go#L159-L163

Added lines #L159 - L163 were not covered by tests
}
case "presence_penalty":
if v, ok := value.(float64); ok {
textRequest.PresencePenalty = v
} else if v, ok := value.(int); ok {
textRequest.PresencePenalty = float64(v)

Check warning on line 169 in relay/controller/text.go

View check run for this annotation

Codecov / codecov/patch

relay/controller/text.go#L165-L169

Added lines #L165 - L169 were not covered by tests
}
case "stop":
textRequest.Stop = value
case "n":
if v, ok := value.(float64); ok {
textRequest.N = int(v)
} else if v, ok := value.(int); ok {
textRequest.N = v

Check warning on line 177 in relay/controller/text.go

View check run for this annotation

Codecov / codecov/patch

relay/controller/text.go#L171-L177

Added lines #L171 - L177 were not covered by tests
}
case "stream":
if v, ok := value.(bool); ok {
textRequest.Stream = v

Check warning on line 181 in relay/controller/text.go

View check run for this annotation

Codecov / codecov/patch

relay/controller/text.go#L179-L181

Added lines #L179 - L181 were not covered by tests
}
case "num_ctx":
if v, ok := value.(float64); ok {
textRequest.NumCtx = int(v)
} else if v, ok := value.(int); ok {
textRequest.NumCtx = v

Check warning on line 187 in relay/controller/text.go

View check run for this annotation

Codecov / codecov/patch

relay/controller/text.go#L183-L187

Added lines #L183 - L187 were not covered by tests
}
// Handle other parameters as needed
default:
logger.Warnf(ctx, "Unknown parameter override key: %s", key)

Check warning on line 191 in relay/controller/text.go

View check run for this annotation

Codecov / codecov/patch

relay/controller/text.go#L190-L191

Added lines #L190 - L191 were not covered by tests
}
}
}
}
}
}
6 changes: 6 additions & 0 deletions relay/meta/relay_meta.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ type Meta struct {
UserId int
Group string
ModelMapping map[string]string
ParamsOverride map[string]map[string]interface{}
// BaseURL is the proxy url set in the channel config
BaseURL string
APIKey string
Expand Down Expand Up @@ -47,6 +48,11 @@ func GetByContext(c *gin.Context) *Meta {
APIKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "),
RequestURLPath: c.Request.URL.String(),
}
// Retrieve ParamsOverride
paramsOverride, exists := c.Get(ctxkey.ParamsOverride)
if exists && paramsOverride != nil {
meta.ParamsOverride = paramsOverride.(map[string]map[string]interface{})

Check warning on line 54 in relay/meta/relay_meta.go

View check run for this annotation

Codecov / codecov/patch

relay/meta/relay_meta.go#L52-L54

Added lines #L52 - L54 were not covered by tests
}
cfg, ok := c.Get(ctxkey.Config)
if ok {
meta.Config = cfg.(model.ChannelConfig)
Expand Down

0 comments on commit e187fe0

Please sign in to comment.