diff --git a/controller/relay-utils.go b/controller/relay-utils.go
new file mode 100644
index 0000000000..a202e69b57
--- /dev/null
+++ b/controller/relay-utils.go
@@ -0,0 +1,61 @@
+package controller
+
+import (
+	"fmt"
+	"github.com/pkoukk/tiktoken-go"
+	"one-api/common"
+	"strings"
+)
+
+var tokenEncoderMap = map[string]*tiktoken.Tiktoken{}
+
+func getTokenEncoder(model string) *tiktoken.Tiktoken {
+	if tokenEncoder, ok := tokenEncoderMap[model]; ok {
+		return tokenEncoder
+	}
+	tokenEncoder, err := tiktoken.EncodingForModel(model)
+	if err != nil {
+		common.FatalLog(fmt.Sprintf("failed to get token encoder for model %s: %s", model, err.Error()))
+	}
+	tokenEncoderMap[model] = tokenEncoder
+	return tokenEncoder
+}
+
+func countTokenMessages(messages []Message, model string) int {
+	tokenEncoder := getTokenEncoder(model)
+	// Reference:
+	// https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
+	// https://github.com/pkoukk/tiktoken-go/issues/6
+	//
+	// Every message follows <|start|>{role/name}\n{content}<|end|>\n
+	var tokensPerMessage int
+	var tokensPerName int
+	if strings.HasPrefix(model, "gpt-3.5") {
+		tokensPerMessage = 4
+		tokensPerName = -1 // If there's a name, the role is omitted
+	} else if strings.HasPrefix(model, "gpt-4") {
+		tokensPerMessage = 3
+		tokensPerName = 1
+	} else {
+		tokensPerMessage = 3
+		tokensPerName = 1
+	}
+	tokenNum := 0
+	for _, message := range messages {
+		tokenNum += tokensPerMessage
+		tokenNum += len(tokenEncoder.Encode(message.Content, nil, nil))
+		tokenNum += len(tokenEncoder.Encode(message.Role, nil, nil))
+		if message.Name != "" {
+			tokenNum += tokensPerName
+			tokenNum += len(tokenEncoder.Encode(message.Name, nil, nil))
+		}
+	}
+	tokenNum += 3 // Every reply is primed with <|start|>assistant<|message|>
+	return tokenNum
+}
+
+func countTokenText(text string, model string) int {
+	tokenEncoder := getTokenEncoder(model)
+	token := tokenEncoder.Encode(text, nil, nil)
+	return len(token)
+}
diff --git a/controller/relay.go b/controller/relay.go
index bc350f0d60..d84a741c1d 100644
--- a/controller/relay.go
+++ b/controller/relay.go
@@ -6,7 +6,6 @@ import (
 	"encoding/json"
 	"fmt"
 	"github.com/gin-gonic/gin"
-	"github.com/pkoukk/tiktoken-go"
 	"io"
 	"net/http"
 	"one-api/common"
@@ -17,6 +16,7 @@ import (
 type Message struct {
 	Role    string `json:"role"`
 	Content string `json:"content"`
+	Name    string `json:"name"`
 }
 
 type ChatRequest struct {
@@ -65,13 +65,6 @@ type StreamResponse struct {
 	} `json:"choices"`
 }
 
-var tokenEncoder, _ = tiktoken.GetEncoding("cl100k_base")
-
-func countToken(text string) int {
-	token := tokenEncoder.Encode(text, nil, nil)
-	return len(token)
-}
-
 func Relay(c *gin.Context) {
 	err := relayHelper(c)
 	if err != nil {
@@ -149,11 +142,8 @@ func relayHelper(c *gin.Context) *OpenAIErrorWithStatusCode {
 		model_ = strings.TrimSuffix(model_, "-0314")
 		fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/%s", baseURL, model_, task)
 	}
-	var promptText string
-	for _, message := range textRequest.Messages {
-		promptText += fmt.Sprintf("%s: %s\n", message.Role, message.Content)
-	}
-	promptTokens := countToken(promptText) + 3
+
+	promptTokens := countTokenMessages(textRequest.Messages, textRequest.Model)
 	preConsumedTokens := common.PreConsumedQuota
 	if textRequest.MaxTokens != 0 {
 		preConsumedTokens = promptTokens + textRequest.MaxTokens
@@ -206,8 +196,8 @@ func relayHelper(c *gin.Context) *OpenAIErrorWithStatusCode {
 		completionRatio = 2
 	}
 	if isStream {
-		completionText := fmt.Sprintf("%s: %s\n", "assistant", streamResponseText)
-		quota = promptTokens + countToken(completionText)*completionRatio
+		responseTokens := countTokenText(streamResponseText, textRequest.Model)
+		quota = promptTokens + responseTokens*completionRatio
 	} else {
 		quota = textResponse.Usage.PromptTokens + textResponse.Usage.CompletionTokens*completionRatio
 	}