diff --git a/common/ctxkey/key.go b/common/ctxkey/key.go index 90556b3af6..ba8c4595b3 100644 --- a/common/ctxkey/key.go +++ b/common/ctxkey/key.go @@ -14,6 +14,7 @@ const ( OriginalModel = "original_model" Group = "group" ModelMapping = "model_mapping" + ParamsOverride = "params_override" ChannelName = "channel_name" TokenId = "token_id" TokenName = "token_name" diff --git a/middleware/distributor.go b/middleware/distributor.go index e2f7511075..a2d8351f1d 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -62,6 +62,7 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode c.Set(ctxkey.ChannelId, channel.Id) c.Set(ctxkey.ChannelName, channel.Name) c.Set(ctxkey.ModelMapping, channel.GetModelMapping()) + c.Set(ctxkey.ParamsOverride, channel.GetParamsOverride()) c.Set(ctxkey.OriginalModel, modelName) // for retry c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key)) c.Set(ctxkey.BaseURL, channel.GetBaseURL()) diff --git a/model/channel.go b/model/channel.go index 759dfd4fed..f9e322a5d0 100644 --- a/model/channel.go +++ b/model/channel.go @@ -35,6 +35,7 @@ type Channel struct { Group string `json:"group" gorm:"type:varchar(32);default:'default'"` UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0"` ModelMapping *string `json:"model_mapping" gorm:"type:varchar(1024);default:''"` + ParamsOverride *string `json:"default_params_override" gorm:"type:text;default:''"` Priority *int64 `json:"priority" gorm:"bigint;default:0"` Config string `json:"config"` } @@ -123,6 +124,20 @@ func (channel *Channel) GetModelMapping() map[string]string { return modelMapping } +func (channel *Channel) GetParamsOverride() map[string]map[string]interface{} { + if channel.ParamsOverride == nil || *channel.ParamsOverride == "" || *channel.ParamsOverride == "{}" { + return nil + } + paramsOverride := make(map[string]map[string]interface{}) + err := json.Unmarshal([]byte(*channel.ParamsOverride), ¶msOverride) + if err != nil { + logger.SysError(fmt.Sprintf("failed to unmarshal params override for channel %d, error: %s", channel.Id, err.Error())) + return nil + } + return paramsOverride +} + + func (channel *Channel) Insert() error { var err error err = DB.Create(channel).Error diff --git a/relay/controller/text.go b/relay/controller/text.go index 52ee9949ae..4d74819f21 100644 --- a/relay/controller/text.go +++ b/relay/controller/text.go @@ -6,6 +6,8 @@ import ( "fmt" "io" "net/http" + "io/ioutil" + "context" "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common/logger" @@ -23,13 +25,34 @@ import ( func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode { ctx := c.Request.Context() meta := meta.GetByContext(c) - // get & validate textRequest - textRequest, err := getAndValidateTextRequest(c, meta.Mode) - if err != nil { - logger.Errorf(ctx, "getAndValidateTextRequest failed: %s", err.Error()) - return openai.ErrorWrapper(err, "invalid_text_request", http.StatusBadRequest) - } - meta.IsStream = textRequest.Stream + + // Read the original request body + bodyBytes, err := ioutil.ReadAll(c.Request.Body) + if err != nil { + logger.Errorf(ctx, "Failed to read request body: %s", err.Error()) + return openai.ErrorWrapper(err, "invalid_request_body", http.StatusBadRequest) + } + + // Restore the request body for `getAndValidateTextRequest` + c.Request.Body = ioutil.NopCloser(bytes.NewBuffer(bodyBytes)) + + // Call `getAndValidateTextRequest` + textRequest, err := getAndValidateTextRequest(c, meta.Mode) + if err != nil { + logger.Errorf(ctx, "getAndValidateTextRequest failed: %s", err.Error()) + return openai.ErrorWrapper(err, "invalid_text_request", http.StatusBadRequest) + } + meta.IsStream = textRequest.Stream + + // Parse the request body into a map + var rawRequest map[string]interface{} + if err := json.Unmarshal(bodyBytes, &rawRequest); err != nil { + logger.Errorf(ctx, "Failed to parse request body into map: %s", err.Error()) + return openai.ErrorWrapper(err, "invalid_json", http.StatusBadRequest) + } + + // Apply parameter overrides + applyParameterOverrides(ctx, meta, textRequest, rawRequest) // map model name meta.OriginModelName = textRequest.Model @@ -105,3 +128,70 @@ func getRequestBody(c *gin.Context, meta *meta.Meta, textRequest *model.GeneralO requestBody = bytes.NewBuffer(jsonData) return requestBody, nil } + +func applyParameterOverrides(ctx context.Context, meta *meta.Meta, textRequest *model.GeneralOpenAIRequest, rawRequest map[string]interface{}) { + if meta.ParamsOverride != nil { + modelName := meta.OriginModelName + if overrideParams, exists := meta.ParamsOverride[modelName]; exists { + logger.Infof(ctx, "Applying parameter overrides for model %s on channel %d", modelName, meta.ChannelId) + for key, value := range overrideParams { + if _, userSpecified := rawRequest[key]; !userSpecified { + // Apply the override since the user didn't specify this parameter + switch key { + case "temperature": + if v, ok := value.(float64); ok { + textRequest.Temperature = v + } else if v, ok := value.(int); ok { + textRequest.Temperature = float64(v) + } + case "max_tokens": + if v, ok := value.(float64); ok { + textRequest.MaxTokens = int(v) + } else if v, ok := value.(int); ok { + textRequest.MaxTokens = v + } + case "top_p": + if v, ok := value.(float64); ok { + textRequest.TopP = v + } else if v, ok := value.(int); ok { + textRequest.TopP = float64(v) + } + case "frequency_penalty": + if v, ok := value.(float64); ok { + textRequest.FrequencyPenalty = v + } else if v, ok := value.(int); ok { + textRequest.FrequencyPenalty = float64(v) + } + case "presence_penalty": + if v, ok := value.(float64); ok { + textRequest.PresencePenalty = v + } else if v, ok := value.(int); ok { + textRequest.PresencePenalty = float64(v) + } + case "stop": + textRequest.Stop = value + case "n": + if v, ok := value.(float64); ok { + textRequest.N = int(v) + } else if v, ok := value.(int); ok { + textRequest.N = v + } + case "stream": + if v, ok := value.(bool); ok { + textRequest.Stream = v + } + case "num_ctx": + if v, ok := value.(float64); ok { + textRequest.NumCtx = int(v) + } else if v, ok := value.(int); ok { + textRequest.NumCtx = v + } + // Handle other parameters as needed + default: + logger.Warnf(ctx, "Unknown parameter override key: %s", key) + } + } + } + } + } +} \ No newline at end of file diff --git a/relay/meta/relay_meta.go b/relay/meta/relay_meta.go index b1761e9a7c..e7e051c6f5 100644 --- a/relay/meta/relay_meta.go +++ b/relay/meta/relay_meta.go @@ -18,6 +18,7 @@ type Meta struct { UserId int Group string ModelMapping map[string]string + ParamsOverride map[string]map[string]interface{} // BaseURL is the proxy url set in the channel config BaseURL string APIKey string @@ -47,6 +48,11 @@ func GetByContext(c *gin.Context) *Meta { APIKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "), RequestURLPath: c.Request.URL.String(), } + // Retrieve ParamsOverride + paramsOverride, exists := c.Get(ctxkey.ParamsOverride) + if exists && paramsOverride != nil { + meta.ParamsOverride = paramsOverride.(map[string]map[string]interface{}) + } cfg, ok := c.Get(ctxkey.Config) if ok { meta.Config = cfg.(model.ChannelConfig)