From 3ab6bd3ed3c29518b59a378878009adcaead5837 Mon Sep 17 00:00:00 2001 From: Eduardo Khattab Date: Wed, 19 Apr 2023 14:58:52 -0300 Subject: [PATCH 1/2] fix message too large exception --- chatservice/Makefile | 6 +- chatservice/docker-compose.yaml | 6 +- chatservice/go.mod | 4 ++ chatservice/go.sum | 1 + chatservice/internal/domain/entity/chat.go | 16 ++++- .../internal/domain/entity/chat_test.go | 69 +++++++++++++++++++ chatservice/internal/domain/entity/message.go | 14 +++- .../internal/infra/grpc/server/server.go | 12 +++- .../internal/infra/web/chat_gpt_handler.go | 6 ++ .../usecase/chatcompletion/completion.go | 7 +- .../chatcompletionstream/completion.go | 7 +- 11 files changed, 134 insertions(+), 14 deletions(-) create mode 100644 chatservice/internal/domain/entity/chat_test.go diff --git a/chatservice/Makefile b/chatservice/Makefile index d5ca540..c7107eb 100644 --- a/chatservice/Makefile +++ b/chatservice/Makefile @@ -10,4 +10,8 @@ migratedown: grpc: protoc --go_out=. --go-grpc_out=. proto/chat.proto --experimental_allow_proto3_optional -.PHONY: migrate createmigration migratedown grpc \ No newline at end of file +tests: + - docker-compose up -d && docker-compose exec chatservice go test ./... + docker-compose down + +.PHONY: migrate createmigration migratedown grpc tests diff --git a/chatservice/docker-compose.yaml b/chatservice/docker-compose.yaml index 14a507b..af8f491 100644 --- a/chatservice/docker-compose.yaml +++ b/chatservice/docker-compose.yaml @@ -4,6 +4,8 @@ services: chatservice: build: . container_name: chatservice_app + depends_on: + - mysql volumes: - .:/go/src ports: @@ -19,6 +21,6 @@ services: MYSQL_DATABASE: chat_test MYSQL_PASSWORD: root ports: - - 3306:3306 + - 3306:3306 volumes: - - .docker/mysql:/var/lib/mysql \ No newline at end of file + - .docker/mysql:/var/lib/mysql diff --git a/chatservice/go.mod b/chatservice/go.mod index e8bbdb7..cbe91f8 100644 --- a/chatservice/go.mod +++ b/chatservice/go.mod @@ -13,21 +13,25 @@ require ( github.com/go-sql-driver/mysql v1.7.0 github.com/sashabaranov/go-openai v1.5.8 github.com/spf13/viper v1.15.0 + github.com/stretchr/testify v1.8.1 google.golang.org/grpc v1.52.0 google.golang.org/protobuf v1.28.1 ) require ( + github.com/davecgh/go-spew v1.1.1 // indirect github.com/fsnotify/fsnotify v1.6.0 // indirect github.com/golang/protobuf v1.5.2 // indirect github.com/hashicorp/hcl v1.0.0 // indirect github.com/magiconair/properties v1.8.7 // indirect github.com/mitchellh/mapstructure v1.5.0 // indirect github.com/pelletier/go-toml/v2 v2.0.6 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect github.com/spf13/afero v1.9.3 // indirect github.com/spf13/cast v1.5.0 // indirect github.com/spf13/jwalterweatherman v1.1.0 // indirect github.com/spf13/pflag v1.0.5 // indirect + github.com/stretchr/objx v0.5.0 // indirect github.com/subosito/gotenv v1.4.2 // indirect golang.org/x/net v0.4.0 // indirect golang.org/x/sys v0.3.0 // indirect diff --git a/chatservice/go.sum b/chatservice/go.sum index f68453b..8614a55 100644 --- a/chatservice/go.sum +++ b/chatservice/go.sum @@ -172,6 +172,7 @@ github.com/spf13/viper v1.15.0 h1:js3yy885G8xwJa6iOISGFwd+qlUo5AvyXb7CiihdtiU= github.com/spf13/viper v1.15.0/go.mod h1:fFcTBJxvhhzSJiZy8n+PeW6t8l+KeT/uTARa0jHOQLA= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0 h1:1zr/of2m5FGMsad5YfcqgdqdWrIhu+EBEJRhR1U7z/c= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= diff --git a/chatservice/internal/domain/entity/chat.go b/chatservice/internal/domain/entity/chat.go index 29e26b6..8a4b666 100644 --- a/chatservice/internal/domain/entity/chat.go +++ b/chatservice/internal/domain/entity/chat.go @@ -37,7 +37,11 @@ func NewChat(userID string, initialSystemMessage *Message, chatConfig *ChatConfi Config: chatConfig, TokenUsage: 0, } - chat.AddMessage(initialSystemMessage) + + err := chat.AddMessage(initialSystemMessage) + if err != nil { + return nil, err + } if err := chat.Validate(); err != nil { return nil, err @@ -63,8 +67,16 @@ func (c *Chat) AddMessage(m *Message) error { if c.Status == "ended" { return errors.New("chat is ended. no more messages allowed") } + + messageTotalTokens := m.GetQtdTokens() + modelMaxTokens := c.Config.Model.GetMaxTokens() + + if messageTotalTokens > modelMaxTokens { + return errors.New("message too large") + } + for { - if c.Config.Model.GetMaxTokens() >= m.GetQtdTokens()+c.TokenUsage { + if modelMaxTokens >= messageTotalTokens+c.TokenUsage { c.Messages = append(c.Messages, m) c.RefreshTokenUsage() break diff --git a/chatservice/internal/domain/entity/chat_test.go b/chatservice/internal/domain/entity/chat_test.go new file mode 100644 index 0000000..4233928 --- /dev/null +++ b/chatservice/internal/domain/entity/chat_test.go @@ -0,0 +1,69 @@ +package entity_test + +import ( + "testing" + + "github.com/devfullcycle/fclx/chatservice/internal/domain/entity" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +var ( + modelName = "gpt-3.5-turbo" + modelMaxTokens = 10 + model = entity.NewModel(modelName, modelMaxTokens) + chatConfig = &entity.ChatConfig{Model: model, Temperature: 1} + + userID = "510ae7e0-4122-49e3-9384-68fa573c2afc" + systemRole = "system" + userRole = "user" + + basicMessageContent = "BasicMessageContent" + messageTooLargeContent = "MessageTooLargeContent" + + mockTikToken = &MockTikToken{} +) + +type MockTikToken struct { + mock.Mock +} + +func (m *MockTikToken) CountTokens(model, prompt string) int { + args := m.Called(model, prompt) + return args.Int(0) +} + +func TestAddMessageShouldNotThrowErrorWhenMessageIsTooLarge(t *testing.T) { + mockTikToken.On("CountTokens", model.Name, basicMessageContent).Return(2) + + initialMessage := newMessage(systemRole, basicMessageContent, model) + + chat, err := entity.NewChat(userID, initialMessage, chatConfig) + if err != nil { + t.Fatal("error creating chat") + } + + mockTikToken.On("CountTokens", model.Name, messageTooLargeContent).Return(99) + messageTooLarge := newMessage(userRole, messageTooLargeContent, model) + + err = chat.AddMessage(messageTooLarge) + errMessage := "message too large" + assert.EqualErrorf(t, err, errMessage, "Error should be: %v, got: %v", errMessage, err) +} + +func TestNewChatShouldNotThrowErrorWhenInitialMessageIsTooLarge(t *testing.T) { + mockTikToken.On("CountTokens", model.Name, messageTooLargeContent).Return(99) + + initialMessage := newMessage(systemRole, messageTooLargeContent, model) + + _, err := entity.NewChat(userID, initialMessage, chatConfig) + errMessage := "message too large" + assert.EqualErrorf(t, err, errMessage, "Error should be: %v, got: %v", errMessage, err) +} + +func newMessage(role, content string, model *entity.Model) *entity.Message { + message, _ := entity.NewMessage(role, content, mockTikToken, model) + + return message +} diff --git a/chatservice/internal/domain/entity/message.go b/chatservice/internal/domain/entity/message.go index 7a200cd..b9ae463 100644 --- a/chatservice/internal/domain/entity/message.go +++ b/chatservice/internal/domain/entity/message.go @@ -17,8 +17,18 @@ type Message struct { CreatedAt time.Time } -func NewMessage(role, content string, model *Model) (*Message, error) { - totalTokens := tiktoken_go.CountTokens(model.GetModelName(), content) +type TikToken interface { + CountTokens(model, prompt string) int +} + +type TikTokenImpl struct{} + +func (t *TikTokenImpl) CountTokens(model, prompt string) int { + return tiktoken_go.CountTokens(model, prompt) +} + +func NewMessage(role, content string, tikToken TikToken, model *Model) (*Message, error) { + totalTokens := tikToken.CountTokens(model.GetModelName(), content) msg := &Message{ ID: uuid.New().String(), Role: role, diff --git a/chatservice/internal/infra/grpc/server/server.go b/chatservice/internal/infra/grpc/server/server.go index a8b7eaa..53191db 100644 --- a/chatservice/internal/infra/grpc/server/server.go +++ b/chatservice/internal/infra/grpc/server/server.go @@ -2,6 +2,7 @@ package server import ( "net" + "strings" "github.com/devfullcycle/fclx/chatservice/internal/infra/grpc/pb" "github.com/devfullcycle/fclx/chatservice/internal/infra/grpc/service" @@ -50,7 +51,16 @@ func (g *GRPCServer) AuthInterceptor(srv interface{}, ss grpc.ServerStream, info return status.Error(codes.Unauthenticated, "authorization token is invalid") } - return handler(srv, ss) + err := handler(srv, ss) + if err != nil { + if strings.Contains(err.Error(), "message too large") { + return status.Error(codes.InvalidArgument, err.Error()) + } + + return status.Errorf(codes.Internal, "internal error %v", err) + } + + return nil } func (g *GRPCServer) Start() error { diff --git a/chatservice/internal/infra/web/chat_gpt_handler.go b/chatservice/internal/infra/web/chat_gpt_handler.go index cb2dad4..e3cc90c 100644 --- a/chatservice/internal/infra/web/chat_gpt_handler.go +++ b/chatservice/internal/infra/web/chat_gpt_handler.go @@ -4,6 +4,7 @@ import ( "encoding/json" "io/ioutil" "net/http" + "strings" "github.com/devfullcycle/fclx/chatservice/internal/usecase/chatcompletion" ) @@ -54,6 +55,11 @@ func (h *WebChatGPTHandler) Handle(w http.ResponseWriter, r *http.Request) { result, err := h.CompletionUseCase.Execute(r.Context(), dto) if err != nil { + if strings.Contains(err.Error(), "message too large") { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + http.Error(w, err.Error(), http.StatusInternalServerError) return } diff --git a/chatservice/internal/usecase/chatcompletion/completion.go b/chatservice/internal/usecase/chatcompletion/completion.go index 42dd8d6..fbf083f 100644 --- a/chatservice/internal/usecase/chatcompletion/completion.go +++ b/chatservice/internal/usecase/chatcompletion/completion.go @@ -64,7 +64,8 @@ func (uc *ChatCompletionUseCase) Execute(ctx context.Context, input ChatCompleti } } - userMessage, err := entity.NewMessage("user", input.UserMessage, chat.Config.Model) + tikToken := &entity.TikTokenImpl{} + userMessage, err := entity.NewMessage("user", input.UserMessage, tikToken, chat.Config.Model) if err != nil { return nil, errors.New("error creating new message: " + err.Error()) } @@ -98,7 +99,7 @@ func (uc *ChatCompletionUseCase) Execute(ctx context.Context, input ChatCompleti return nil, errors.New("error openai: " + err.Error()) } - assistant, err := entity.NewMessage("assistant", resp.Choices[0].Message.Content, chat.Config.Model) + assistant, err := entity.NewMessage("assistant", resp.Choices[0].Message.Content, tikToken, chat.Config.Model) if err != nil { return nil, err } @@ -134,7 +135,7 @@ func createNewChat(input ChatCompletionInputDTO) (*entity.Chat, error) { Model: model, } - initialMessage, err := entity.NewMessage("system", input.Config.InitialSystemMessage, model) + initialMessage, err := entity.NewMessage("system", input.Config.InitialSystemMessage, &entity.TikTokenImpl{}, model) if err != nil { return nil, errors.New("error creating initial message: " + err.Error()) } diff --git a/chatservice/internal/usecase/chatcompletionstream/completion.go b/chatservice/internal/usecase/chatcompletionstream/completion.go index 9d5dcec..3948dc4 100644 --- a/chatservice/internal/usecase/chatcompletionstream/completion.go +++ b/chatservice/internal/usecase/chatcompletionstream/completion.go @@ -68,7 +68,8 @@ func (uc *ChatCompletionUseCase) Execute(ctx context.Context, input ChatCompleti } } - userMessage, err := entity.NewMessage("user", input.UserMessage, chat.Config.Model) + tikToken := &entity.TikTokenImpl{} + userMessage, err := entity.NewMessage("user", input.UserMessage, tikToken, chat.Config.Model) if err != nil { return nil, errors.New("error creating new message: " + err.Error()) } @@ -122,7 +123,7 @@ func (uc *ChatCompletionUseCase) Execute(ctx context.Context, input ChatCompleti uc.Stream <- r } - assistant, err := entity.NewMessage("assistant", fullResponse.String(), chat.Config.Model) + assistant, err := entity.NewMessage("assistant", fullResponse.String(), tikToken, chat.Config.Model) if err != nil { return nil, err } @@ -158,7 +159,7 @@ func createNewChat(input ChatCompletionInputDTO) (*entity.Chat, error) { Model: model, } - initialMessage, err := entity.NewMessage("system", input.Config.InitialSystemMessage, model) + initialMessage, err := entity.NewMessage("system", input.Config.InitialSystemMessage, &entity.TikTokenImpl{}, model) if err != nil { return nil, errors.New("error creating initial message: " + err.Error()) } From ebcb51519796f1abeefde8b678ee3a7616ebb868 Mon Sep 17 00:00:00 2001 From: Eduardo Khattab Date: Wed, 19 Apr 2023 16:04:11 -0300 Subject: [PATCH 2/2] remove error handling on test --- chatservice/internal/domain/entity/chat_test.go | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/chatservice/internal/domain/entity/chat_test.go b/chatservice/internal/domain/entity/chat_test.go index 4233928..46fb7d5 100644 --- a/chatservice/internal/domain/entity/chat_test.go +++ b/chatservice/internal/domain/entity/chat_test.go @@ -39,15 +39,12 @@ func TestAddMessageShouldNotThrowErrorWhenMessageIsTooLarge(t *testing.T) { initialMessage := newMessage(systemRole, basicMessageContent, model) - chat, err := entity.NewChat(userID, initialMessage, chatConfig) - if err != nil { - t.Fatal("error creating chat") - } + chat, _ := entity.NewChat(userID, initialMessage, chatConfig) mockTikToken.On("CountTokens", model.Name, messageTooLargeContent).Return(99) messageTooLarge := newMessage(userRole, messageTooLargeContent, model) - err = chat.AddMessage(messageTooLarge) + err := chat.AddMessage(messageTooLarge) errMessage := "message too large" assert.EqualErrorf(t, err, errMessage, "Error should be: %v, got: %v", errMessage, err) }