Skip to content

Commit

Permalink
减少不必要的登录检查
Browse files Browse the repository at this point in the history
  • Loading branch information
kkkunny committed Nov 23, 2024
1 parent 7f7b855 commit 379405f
Show file tree
Hide file tree
Showing 9 changed files with 62 additions and 51 deletions.
80 changes: 46 additions & 34 deletions hugchat/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"strings"

stlslices "github.com/kkkunny/stl/container/slices"
"github.com/kkkunny/stl/container/tuple"
stlerr "github.com/kkkunny/stl/error"

"github.com/kkkunny/HuggingChatAPI/config"
Expand All @@ -25,31 +26,16 @@ func NewClient(tokenProvider TokenProvider) *Client {
}
}

// CheckLogin 检查并刷新登录
func (c *Client) CheckLogin(ctx context.Context) error {
token, err := c.tokenProvider.GetToken(ctx)
if err != nil {
func (c *Client) handleUnauthorized(ctx context.Context, f func() error) error {
err := f()
if !errors.Is(err, api.ErrUnauthorized) {
return err
}
login, err := api.CheckLogin(ctx, token)
_, err = c.tokenProvider.RefreshToken(ctx)
if err != nil {
return err
}
if login {
return nil
}

token, err = c.tokenProvider.RefreshToken(ctx)
if err != nil {
return err
}
login, err = api.CheckLogin(ctx, token)
if err != nil {
return err
} else if !login {
return stlerr.Errorf("not login")
}
return nil
return f()
}

// ListModels 列出模型
Expand All @@ -58,7 +44,11 @@ func (c *Client) ListModels(ctx context.Context) ([]*dto.ModelInfo, error) {
if err != nil {
return nil, err
}
models, _, err := api.ListModelsAndConversations(ctx, token)
var models []*api.ModelInfo
err = c.handleUnauthorized(ctx, func() error {
models, _, err = api.ListModelsAndConversations(ctx, token)
return err
})
if err != nil {
return nil, err
}
Expand All @@ -73,7 +63,11 @@ func (c *Client) ListConversations(ctx context.Context) ([]*dto.SimpleConversati
if err != nil {
return nil, err
}
_, convs, err := api.ListModelsAndConversations(ctx, token)
var convs []*api.SimpleConversationInfo
err = c.handleUnauthorized(ctx, func() error {
_, convs, err = api.ListModelsAndConversations(ctx, token)
return err
})
if err != nil {
return nil, err
}
Expand All @@ -88,7 +82,11 @@ func (c *Client) ConversationInfo(ctx context.Context, convID string) (*dto.Conv
if err != nil {
return nil, err
}
conv, err := api.ConversationInfo(ctx, token, convID)
var conv *api.DetailConversationInfo
err = c.handleUnauthorized(ctx, func() error {
conv, err = api.ConversationInfo(ctx, token, convID)
return err
})
return dto.NewConversationInfoFromAPI(conv), err
}

Expand All @@ -98,15 +96,23 @@ func (c *Client) CreateConversation(ctx context.Context, model string, systemPro
if err != nil {
return nil, err
}
createResp, err := api.CreateConversation(ctx, token, &api.CreateConversationRequest{
Model: model,
PrePrompt: systemPrompt,
var createResp *api.CreateConversationResponse
err = c.handleUnauthorized(ctx, func() error {
createResp, err = api.CreateConversation(ctx, token, &api.CreateConversationRequest{
Model: model,
PrePrompt: systemPrompt,
})
return err
})
if err != nil {
return nil, err
}

info, err := api.ConversationInfoAfterCreate(ctx, token, createResp.ConversationID)
var info *api.DetailConversationInfo
err = c.handleUnauthorized(ctx, func() error {
info, err = api.ConversationInfoAfterCreate(ctx, token, createResp.ConversationID)
return err
})
return dto.NewConversationInfoFromAPI(info), err
}

Expand All @@ -116,7 +122,9 @@ func (c *Client) DeleteConversation(ctx context.Context, convID string) error {
if err != nil {
return err
}
return api.DeleteConversation(ctx, token, convID)
return c.handleUnauthorized(ctx, func() error {
return api.DeleteConversation(ctx, token, convID)
})
}

type ChatConversationParams struct {
Expand All @@ -131,12 +139,16 @@ func (c *Client) ChatConversation(ctx context.Context, convID string, params *Ch
if err != nil {
return nil, err
}
msgDataChan, err := api.ChatConversation(ctx, token, &api.ChatConversationRequest{
ConversationID: convID,
ID: params.LastMsgID,
Inputs: params.Inputs,
WebSearch: params.WebSearch,
Tools: params.Tools,
var msgDataChan chan tuple.Tuple2[string, error]
err = c.handleUnauthorized(ctx, func() error {
msgDataChan, err = api.ChatConversation(ctx, token, &api.ChatConversationRequest{
ConversationID: convID,
ID: params.LastMsgID,
Inputs: params.Inputs,
WebSearch: params.WebSearch,
Tools: params.Tools,
})
return err
})
if err != nil {
return nil, err
Expand Down
7 changes: 6 additions & 1 deletion internal/api/conversation_info.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"encoding/json"
"net/http"
"regexp"
"strings"
"time"

request "github.com/imroc/req/v3"
Expand Down Expand Up @@ -120,5 +121,9 @@ func ConversationInfo(ctx context.Context, cookies []*http.Cookie, convID string
if err != nil {
return nil, err
}
return parseDetailConversationInfo(convID, (*httpResp)["nodes"].([]any)[1].(map[string]any)["data"].([]any))
node := (*httpResp)["nodes"].([]any)[1].(map[string]any)
if node["type"] == "error" && strings.Contains(node["error"].(map[string]any)["message"].(string), "access to") {
return nil, stlerr.ErrorWrap(ErrUnauthorized)
}
return parseDetailConversationInfo(convID, node["data"].([]any))
}
3 changes: 3 additions & 0 deletions internal/api/login.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package api

import (
"context"
"errors"
"fmt"
"net/http"
"net/url"
Expand All @@ -13,6 +14,8 @@ import (
"github.com/kkkunny/HuggingChatAPI/config"
)

var ErrUnauthorized = errors.New("unauthorized")

// Login 登录
func Login(ctx context.Context, username string, password string) ([]*http.Cookie, error) {
cli := globalHttpClient.Clone()
Expand Down
2 changes: 2 additions & 0 deletions internal/api/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ func sendDefaultHttpRequest[Result any](ctx context.Context, method string, reqH
resp, err := stlerr.ErrorWith(req.Send(method, uri))
if err != nil {
return nil, err
} else if resp.GetStatusCode() == http.StatusUnauthorized {
return nil, stlerr.ErrorWrap(ErrUnauthorized)
} else if resp.GetStatusCode() != http.StatusOK {
return nil, stlerr.Errorf("http error: code=%d, status=%s", resp.GetStatusCode(), resp.GetStatus())
}
Expand Down
5 changes: 0 additions & 5 deletions web/chat_completions.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,6 @@ func chatCompletions(reqCtx echo.Context) error {
return echo.ErrUnauthorized
}
cli := hugchat.NewClient(tokenProvider)
err = cli.CheckLogin(reqCtx.Request().Context())
if err != nil {
_ = config.Logger.Error(err)
return echo.ErrUnauthorized
}

var req openai.ChatCompletionRequest
if err = stlerr.ErrorWrap(reqCtx.Bind(&req)); err != nil {
Expand Down
5 changes: 0 additions & 5 deletions web/list_models.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,6 @@ func listModels(reqCtx echo.Context) error {
return echo.ErrUnauthorized
}
cli := hugchat.NewClient(tokenProvider)
err = cli.CheckLogin(reqCtx.Request().Context())
if err != nil {
_ = config.Logger.Error(err)
return echo.ErrUnauthorized
}

models, err := cli.ListModels(reqCtx.Request().Context())
if err != nil {
Expand Down
3 changes: 1 addition & 2 deletions web/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"github.com/labstack/gommon/log"

"github.com/kkkunny/HuggingChatAPI/config"
"github.com/kkkunny/HuggingChatAPI/middleware"
)

func main() {
Expand All @@ -15,7 +14,7 @@ func main() {
svr.Logger.SetLevel(log.OFF)
svr.IPExtractor = echo.ExtractIPFromRealIPHeader()

svr.Use(middleware.ErrorHandler, middleware.Logger)
svr.Use(midErrorHandler, midLogger)

svr.GET("/v1/models", listModels)
svr.POST("/v1/chat/completions", chatCompletions)
Expand Down
4 changes: 2 additions & 2 deletions middleware/error_handler.go → web/mid_error_handler.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package middleware
package main

import (
"errors"
Expand All @@ -10,7 +10,7 @@ import (
"github.com/kkkunny/HuggingChatAPI/config"
)

func ErrorHandler(next echo.HandlerFunc) echo.HandlerFunc {
func midErrorHandler(next echo.HandlerFunc) echo.HandlerFunc {
return func(reqCtx echo.Context) (err error) {
var isPanic bool

Expand Down
4 changes: 2 additions & 2 deletions middleware/logger.go → web/mid_logger.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
package middleware
package main

import (
"github.com/labstack/echo/v4"

"github.com/kkkunny/HuggingChatAPI/config"
)

func Logger(next echo.HandlerFunc) echo.HandlerFunc {
func midLogger(next echo.HandlerFunc) echo.HandlerFunc {
return func(reqCtx echo.Context) error {
_ = config.Logger.Infof("Method [%s] %s --> %s", reqCtx.Request().Method, reqCtx.RealIP(), reqCtx.Path())
return next(reqCtx)
Expand Down

0 comments on commit 379405f

Please sign in to comment.