From 64de7b28acaa50301fc9ad8372e54ca20918e7a8 Mon Sep 17 00:00:00 2001 From: kkkunny <51853085+kkkunny@users.noreply.github.com> Date: Sat, 9 Nov 2024 15:18:47 +0800 Subject: [PATCH] =?UTF-8?q?=E6=94=AF=E6=8C=81=E7=99=BB=E5=BD=95=20(#7)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- Dockerfile | 1 + README.md | 18 ++- handler/chat_completions.go | 20 ++- handler/list_models.go | 15 +- internal/api/api.go | 96 +++++++++--- internal/api/cookie_cache.go | 66 ++++++++ internal/api/cookie_mgr.go | 61 ++++++++ internal/api/init.go | 16 ++ internal/api/login.go | 142 ++++++++++++++++++ .../{consts/const.go => config/domain.go} | 2 +- internal/config/file.go | 3 + internal/config/proxy.go | 25 ++- 12 files changed, 418 insertions(+), 47 deletions(-) create mode 100644 internal/api/cookie_cache.go create mode 100644 internal/api/cookie_mgr.go create mode 100644 internal/api/init.go create mode 100644 internal/api/login.go rename internal/{consts/const.go => config/domain.go} (77%) create mode 100644 internal/config/file.go diff --git a/Dockerfile b/Dockerfile index e3a2324..c7d4dec 100644 --- a/Dockerfile +++ b/Dockerfile @@ -14,6 +14,7 @@ RUN apk --no-cache add tzdata && \ ENV TZ Asia/Shanghai RUN apk add --no-cache ca-certificates && update-ca-certificates WORKDIR /app +RUN mkdir config COPY --from=builder /app/bin/* /app EXPOSE 80 ENTRYPOINT ["/app/server"] \ No newline at end of file diff --git a/README.md b/README.md index e128ff2..491acc1 100644 --- a/README.md +++ b/README.md @@ -4,13 +4,15 @@ ## 使用说明 -1. **访问 HuggingChat 官网** - 进入 [HuggingChat官网](https://huggingface.co/chat) 并登录。 +1. **获取 Authorization** + 以下两种方式任选其一! -2. **获取 Authorization** - 从浏览器的 Cookie 中取出 `hf-chat` 的值,作为 Authorization。 + + 使用base64编码`username={你的账号名}&password={你的密码}`,假设你的账号名为usr,密码为pwd,则`username=usr&password=pwd`进行base64编码后结果为`dXNlcm5hbWU9dXNyJnBhc3N3b3JkPXB3ZA==` + + 进入 [HuggingChat官网](https://huggingface.co/chat) 并登录,从浏览器的 Cookie 中取出 `hf-chat` 的值,形如`cc43f26e-142b-409f-b228-68316s5x30a9` -3. **会话创建注意事项** + 将上面任意一种方式获得的值填入Authorization + +2. **!!注意事项!!会话创建** 由于调用创建会话的接口创建出的会话总是呈不可用状态,因此需要提前为每个模型创建好会话。 ## 调用说明 @@ -98,6 +100,12 @@ curl -X POST "http://localhost:5695/v1/chat/completions" \ docker run -d --name HuggingChat -p 5695:80 kkkunny/hugging-chat-api:latest ``` +强烈建议将/app/config映射到本地路径 + +```bash +docker run -d --name HuggingChat -p 5695:80 -v YOUR_PATH:/app/config kkkunny/hugging-chat-api:latest +``` + ### 使用 Koyeb 一键部署 1. **准备工作** diff --git a/handler/chat_completions.go b/handler/chat_completions.go index d76bfd7..417f64e 100644 --- a/handler/chat_completions.go +++ b/handler/chat_completions.go @@ -17,24 +17,34 @@ import ( "github.com/kkkunny/HuggingChatAPI/internal/api" "github.com/kkkunny/HuggingChatAPI/internal/config" - "github.com/kkkunny/HuggingChatAPI/internal/consts" ) func ChatCompletions(w http.ResponseWriter, r *http.Request) { token := strings.TrimPrefix(r.Header.Get("Authorization"), "Bearer ") - cli := api.NewAPI(consts.HuggingChatDomain, token, config.Proxy) + cli, err := api.NewAPI(config.HuggingChatDomain, token) + if err != nil { + config.Logger.Error(err) + http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized) + return + } + err = cli.RefreshCookie(r.Context()) + if err != nil { + config.Logger.Error(err) + http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized) + return + } var req openai.ChatCompletionRequest body, err := io.ReadAll(r.Body) if err != nil { config.Logger.Error(err) - http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) + http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest) return } err = json.Unmarshal(body, &req) if err != nil { config.Logger.Error(err) - http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) + http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest) return } @@ -128,7 +138,7 @@ func chatCompletionsNoStream(w http.ResponseWriter, msgID string, convInfo *api. Type: openai.ChatMessagePartTypeImageURL, ImageURL: &openai.ChatMessageImageURL{ Detail: openai.ImageURLDetailAuto, - URL: fmt.Sprintf("%s/chat/conversation/%s/output/%s", consts.HuggingChatDomain, convInfo.ConversationID, *msg.SHA), + URL: fmt.Sprintf("%s/chat/conversation/%s/output/%s", config.HuggingChatDomain, convInfo.ConversationID, *msg.SHA), }, }) } diff --git a/handler/list_models.go b/handler/list_models.go index 84e4655..7c43d8f 100644 --- a/handler/list_models.go +++ b/handler/list_models.go @@ -11,12 +11,23 @@ import ( "github.com/kkkunny/HuggingChatAPI/internal/api" "github.com/kkkunny/HuggingChatAPI/internal/config" - "github.com/kkkunny/HuggingChatAPI/internal/consts" ) func ListModels(w http.ResponseWriter, r *http.Request) { token := strings.TrimPrefix(r.Header.Get("Authorization"), "Bearer ") - cli := api.NewAPI(consts.HuggingChatDomain, token, config.Proxy) + cli, err := api.NewAPI(config.HuggingChatDomain, token) + if err != nil { + config.Logger.Error(err) + http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized) + return + } + err = cli.RefreshCookie(r.Context()) + if err != nil { + config.Logger.Error(err) + http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized) + return + } + models, err := cli.ListModels(r.Context()) if err != nil { config.Logger.Error(err) diff --git a/internal/api/api.go b/internal/api/api.go index 2b33269..11d6d04 100644 --- a/internal/api/api.go +++ b/internal/api/api.go @@ -3,12 +3,12 @@ package api import ( "bufio" "context" + "encoding/base64" "encoding/json" "errors" "fmt" "io" "net/http" - "net/url" "regexp" "strings" "time" @@ -18,18 +18,77 @@ import ( ) type Api struct { - domain string - client *req.Client + domain string + client *req.Client + cookieMgr cookieMgr } -func NewAPI(domain string, token string, proxy func(*http.Request) (*url.URL, error)) *Api { - return &Api{ +func NewAPI(domain string, token string) (*Api, error) { + api := &Api{ domain: domain, - client: req.C(). - SetProxy(proxy). - SetCommonCookies(&http.Cookie{Name: "hf-chat", Value: token}). - SetUserAgent("Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/124.0.0.0 Safari/537.36 Edg/124.0.0.0"), + client: globalClient.Clone(). + SetCommonHeader("origin", domain), } + err := api.SetToken(token) + return api, err +} + +func (api *Api) SetToken(token string) error { + account, err := base64.StdEncoding.DecodeString(token) + if err == nil { + res := regexp.MustCompile(`username=(.+?)&password=(.+)`).FindStringSubmatch(string(account)) + if len(res) != 3 { + return errors.New("invalid token") + } + api.cookieMgr = newAccountCookieMgr(res[1], res[2]) + return nil + } + api.cookieMgr = newTokenCookieMgr(token) + return nil +} + +func (api *Api) RefreshCookie(ctx context.Context) error { + cookies, err := api.cookieMgr.Cookies(ctx) + if err != nil { + return err + } + api.client.ClearCookies().SetCommonCookies(cookies...) + + isLogin, err := api.CheckLogin(ctx) + if err != nil { + return err + } else if isLogin { + return nil + } + + cookies, err = api.cookieMgr.Refresh(ctx) + if err != nil { + return err + } + api.client.ClearCookies().SetCommonCookies(cookies...) + + isLogin, err = api.CheckLogin(ctx) + if err != nil { + return err + } else if isLogin { + return nil + } + return errors.New("login failed") +} + +func (api *Api) CheckLogin(ctx context.Context) (bool, error) { + httpResp, err := api.client.R(). + SetContext(ctx). + SetHeaders(map[string]string{ + "accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.7", + }). + Get(fmt.Sprintf("%s/chat/", api.domain)) + if err != nil { + return false, err + } else if httpResp.GetStatusCode() != http.StatusOK { + return false, fmt.Errorf("http error: code=%d, status=%s", httpResp.GetStatusCode(), httpResp.GetStatus()) + } + return !strings.Contains(httpResp.String(), "action=\"/chat/login\""), nil } type ModelInfo struct { @@ -41,11 +100,10 @@ type ModelInfo struct { } func (api *Api) ListModels(ctx context.Context) (resp []*ModelInfo, err error) { - urlStr := fmt.Sprintf("%s/chat/models/__data.json?x-sveltekit-invalidated=10", api.domain) httpResp, err := api.client.R(). SetContext(ctx). SetSuccessResult(make(map[string]any)). - Get(urlStr) + Get(fmt.Sprintf("%s/chat/models/__data.json?x-sveltekit-invalidated=10", api.domain)) if err != nil { return nil, err } else if httpResp.GetStatusCode() != http.StatusOK { @@ -107,12 +165,11 @@ type CreateConversationResponse struct { } func (api *Api) CreateConversation(ctx context.Context, req *CreateConversationRequest) (*CreateConversationResponse, error) { - urlStr := fmt.Sprintf("%s/chat/conversation", api.domain) httpResp, err := api.client.R(). SetContext(ctx). SetBodyJsonMarshal(req). SetSuccessResult(CreateConversationResponse{}). - Post(urlStr) + Post(fmt.Sprintf("%s/chat/conversation", api.domain)) if err != nil { return nil, err } else if httpResp.GetStatusCode() != http.StatusOK { @@ -127,10 +184,9 @@ type DeleteConversationRequest struct { } func (api *Api) DeleteConversation(ctx context.Context, req *DeleteConversationRequest) error { - urlStr := fmt.Sprintf("%s/chat/conversation/%s", api.domain, req.ConversationID) httpResp, err := api.client.R(). SetContext(ctx). - Delete(urlStr) + Delete(fmt.Sprintf("%s/chat/conversation/%s", api.domain, req.ConversationID)) if err != nil { return err } else if httpResp.GetStatusCode() != http.StatusOK { @@ -147,10 +203,9 @@ type SimpleConversationInfo struct { } func (api *Api) ListConversations(ctx context.Context) (resp []*SimpleConversationInfo, err error) { - urlStr := fmt.Sprintf("%s/chat/models/__data.json?x-sveltekit-invalidated=10", api.domain) httpResp, err := api.client.R(). SetContext(ctx). - Get(urlStr) + Get(fmt.Sprintf("%s/chat/models/__data.json?x-sveltekit-invalidated=10", api.domain)) if err != nil { return nil, err } else if httpResp.GetStatusCode() != http.StatusOK { @@ -225,11 +280,10 @@ type Message struct { } func (api *Api) ConversationInfo(ctx context.Context, req *ConversationInfoRequest) (resp *ConversationInfoResponse, err error) { - urlStr := fmt.Sprintf("%s/chat/conversation/%s/__data.json?x-sveltekit-invalidated=01", api.domain, req.ConversationID) httpResp, err := api.client.R(). SetContext(ctx). SetSuccessResult(make(map[string]any)). - Get(urlStr) + Get(fmt.Sprintf("%s/chat/conversation/%s/__data.json?x-sveltekit-invalidated=01", api.domain, req.ConversationID)) if err != nil { return nil, err } else if httpResp.GetStatusCode() != http.StatusOK { @@ -374,7 +428,7 @@ func (api *Api) ChatConversation(ctx context.Context, req *ChatConversationReque if err != nil { return nil, err } - urlStr := fmt.Sprintf("%s/chat/conversation/%s", api.domain, req.ConversationID) + resp, err := api.client.R(). SetContext(ctx). SetHeaders(map[string]string{ @@ -391,7 +445,7 @@ func (api *Api) ChatConversation(ctx context.Context, req *ChatConversationReque }). SetFormData(map[string]string{"data": string(reqBody)}). DisableAutoReadResponse(). - Post(urlStr) + Post(fmt.Sprintf("%s/chat/conversation/%s", api.domain, req.ConversationID)) if err != nil { return nil, err } else if resp.GetStatusCode() != http.StatusOK { diff --git a/internal/api/cookie_cache.go b/internal/api/cookie_cache.go new file mode 100644 index 0000000..bd88bda --- /dev/null +++ b/internal/api/cookie_cache.go @@ -0,0 +1,66 @@ +package api + +import ( + "encoding/json" + "net/http" + "os" + "path/filepath" + + stlerr "github.com/kkkunny/stl/error" + + "github.com/kkkunny/HuggingChatAPI/internal/config" +) + +var globalCookieCache *cookieCache + +func init() { + globalCookieCache = newCookieCache() + stlerr.Must(globalCookieCache.Load()) +} + +type cookieCache struct { + data map[string][]*http.Cookie +} + +func newCookieCache() *cookieCache { + return &cookieCache{data: make(map[string][]*http.Cookie)} +} + +func (cache *cookieCache) Load() error { + data, err := os.ReadFile(config.CookieCachePath) + if err != nil && os.IsNotExist(err) { + return nil + } else if err != nil { + return err + } + err = json.Unmarshal(data, &cache.data) + return err +} + +func (cache *cookieCache) Save() error { + err := os.MkdirAll(filepath.Dir(config.CookieCachePath), 0750) + if err != nil { + return err + } + file, err := os.OpenFile(config.CookieCachePath, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0666) + if err != nil { + return err + } + defer file.Close() + + data, err := json.MarshalIndent(cache.data, "", " ") + if err != nil { + return err + } + _, err = file.Write(data) + return err +} + +func (cache *cookieCache) Get(usr string) []*http.Cookie { + return cache.data[usr] +} + +func (cache *cookieCache) Set(usr string, cookies []*http.Cookie) error { + cache.data[usr] = cookies + return cache.Save() +} diff --git a/internal/api/cookie_mgr.go b/internal/api/cookie_mgr.go new file mode 100644 index 0000000..65101c7 --- /dev/null +++ b/internal/api/cookie_mgr.go @@ -0,0 +1,61 @@ +package api + +import ( + "context" + "errors" + "net/http" + + stlslices "github.com/kkkunny/stl/container/slices" +) + +type cookieMgr interface { + Cookies(ctx context.Context) ([]*http.Cookie, error) + Refresh(ctx context.Context) ([]*http.Cookie, error) +} + +type tokenCookieMgr struct { + token string +} + +func newTokenCookieMgr(token string) *tokenCookieMgr { + return &tokenCookieMgr{token: token} +} + +func (mgr *tokenCookieMgr) Cookies(_ context.Context) ([]*http.Cookie, error) { + return []*http.Cookie{{Name: "hf-chat", Value: mgr.token}}, nil +} + +func (mgr *tokenCookieMgr) Refresh(_ context.Context) ([]*http.Cookie, error) { + return nil, errors.New("can not refresh token cookie") +} + +type accountCookieMgr struct { + username string + password string +} + +func newAccountCookieMgr(usr, pwd string) *accountCookieMgr { + return &accountCookieMgr{ + username: usr, + password: pwd, + } +} + +func (mgr *accountCookieMgr) Cookies(ctx context.Context) ([]*http.Cookie, error) { + cookies := globalCookieCache.Get(mgr.username) + if len(stlslices.Filter(cookies, func(_ int, cookie *http.Cookie) bool { + return cookie.Name == "hf-chat" + })) == 0 { + return mgr.Refresh(ctx) + } + return globalCookieCache.Get(mgr.username), nil +} + +func (mgr *accountCookieMgr) Refresh(ctx context.Context) ([]*http.Cookie, error) { + cookies, err := Login(ctx, mgr.username, mgr.password) + if err != nil { + return nil, err + } + err = globalCookieCache.Set(mgr.username, cookies) + return cookies, err +} diff --git a/internal/api/init.go b/internal/api/init.go new file mode 100644 index 0000000..b3083fb --- /dev/null +++ b/internal/api/init.go @@ -0,0 +1,16 @@ +package api + +import ( + "github.com/imroc/req/v3" + + "github.com/kkkunny/HuggingChatAPI/internal/config" +) + +var globalClient *req.Client + +func init() { + globalClient = req.C(). + SetProxy(config.Proxy). + SetRedirectPolicy(req.NoRedirectPolicy()). + SetUserAgent("Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/124.0.0.0 Safari/537.36 Edg/124.0.0.0") +} diff --git a/internal/api/login.go b/internal/api/login.go new file mode 100644 index 0000000..e08cffd --- /dev/null +++ b/internal/api/login.go @@ -0,0 +1,142 @@ +package api + +import ( + "context" + "fmt" + "net/http" + "net/url" + + "github.com/imroc/req/v3" + "golang.org/x/exp/maps" + + "github.com/kkkunny/HuggingChatAPI/internal/config" +) + +// Login 登录 +func Login(ctx context.Context, username string, password string) ([]*http.Cookie, error) { + cli := globalClient.Clone() + cli.SetCommonHeader("origin", config.HuggingChatDomain) + + loginResp, err := login(ctx, cli, &loginRequest{ + Username: username, + Password: password, + }) + if err != nil { + return nil, err + } + cli.SetCommonCookies(loginResp.Cookies...) + + chatLoginResp, err := chatLogin(ctx, cli) + if err != nil { + return nil, err + } + cli.SetCommonCookies(chatLoginResp.Cookies...) + + authorizeOauthResp, err := authorizeOauth(ctx, cli, chatLoginResp.Location.String()) + if err != nil { + return nil, err + } + + loginCallbackResp, err := loginCallback(ctx, cli, authorizeOauthResp.Location.String()) + if err != nil { + return nil, err + } + cli.SetCommonCookies(loginCallbackResp.Cookies...) + + cookies := make(map[string]*http.Cookie, len(cli.Cookies)) + for _, cookie := range cli.Cookies { + cookies[cookie.Name] = cookie + } + return maps.Values(cookies), nil +} + +type loginRequest struct { + Location string + Username string + Password string +} + +type loginResponse struct { + Cookies []*http.Cookie +} + +func login(ctx context.Context, cli *req.Client, req *loginRequest) (*loginResponse, error) { + resp, err := cli.R(). + SetContext(ctx). + SetContentType("application/x-www-form-urlencoded"). + SetBodyString(fmt.Sprintf("username=%s&password=%s", req.Username, req.Password)). + Post(fmt.Sprintf("%s/login", config.HuggingChatDomain)) + if err != nil { + return nil, err + } else if resp.GetStatusCode() != http.StatusFound { + return nil, fmt.Errorf("http error: code=%d, status=%s", resp.GetStatusCode(), resp.GetStatus()) + } + return &loginResponse{Cookies: resp.Cookies()}, nil +} + +type chatLoginResponse struct { + Location *url.URL + Cookies []*http.Cookie +} + +func chatLogin(ctx context.Context, cli *req.Client) (*chatLoginResponse, error) { + resp, err := cli.R(). + SetContext(ctx). + SetContentType("application/x-www-form-urlencoded"). + SetHeaders(map[string]string{ + "accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.7", + }). + Post(fmt.Sprintf("%s/chat/login", config.HuggingChatDomain)) + if err != nil { + return nil, err + } else if resp.GetStatusCode() != http.StatusSeeOther { + return nil, fmt.Errorf("http error: code=%d, status=%s", resp.GetStatusCode(), resp.GetStatus()) + } + location, err := resp.Location() + if err != nil { + return nil, err + } + return &chatLoginResponse{ + Location: location, + Cookies: resp.Cookies(), + }, nil +} + +type authorizeOauthResponse struct { + Location *url.URL +} + +func authorizeOauth(ctx context.Context, cli *req.Client, urlStr string) (*authorizeOauthResponse, error) { + resp, err := cli.R(). + SetContext(ctx). + SetHeaders(map[string]string{ + "accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.7", + }). + Get(urlStr) + if err != nil { + return nil, err + } else if resp.GetStatusCode() != http.StatusSeeOther { + return nil, fmt.Errorf("http error: code=%d, status=%s", resp.GetStatusCode(), resp.GetStatus()) + } + location, err := resp.Location() + if err != nil { + return nil, err + } + return &authorizeOauthResponse{Location: location}, nil +} + +type loginCallbackResponse struct { + Cookies []*http.Cookie +} + +func loginCallback(ctx context.Context, cli *req.Client, urlStr string) (*loginCallbackResponse, error) { + resp, err := cli.R(). + SetContext(ctx). + Get(urlStr) + if err != nil { + return nil, err + } else if resp.GetStatusCode() != http.StatusFound { + return nil, fmt.Errorf("http error: code=%d, status=%s", resp.GetStatusCode(), resp.GetStatus()) + } + return &loginCallbackResponse{Cookies: resp.Cookies()}, nil +} diff --git a/internal/consts/const.go b/internal/config/domain.go similarity index 77% rename from internal/consts/const.go rename to internal/config/domain.go index eecde1b..6bb1303 100644 --- a/internal/consts/const.go +++ b/internal/config/domain.go @@ -1,3 +1,3 @@ -package consts +package config const HuggingChatDomain = "https://huggingface.co" diff --git a/internal/config/file.go b/internal/config/file.go new file mode 100644 index 0000000..3273f22 --- /dev/null +++ b/internal/config/file.go @@ -0,0 +1,3 @@ +package config + +const CookieCachePath = "config/cookies.json" diff --git a/internal/config/proxy.go b/internal/config/proxy.go index 67258f2..872b8f9 100644 --- a/internal/config/proxy.go +++ b/internal/config/proxy.go @@ -4,19 +4,18 @@ import ( "net/http" "net/url" "os" + + stlerr "github.com/kkkunny/stl/error" + stlval "github.com/kkkunny/stl/value" ) -var Proxy = func() func(*http.Request) (*url.URL, error) { - proxy := os.Getenv("https_proxy") - if proxy == "" { - proxy = os.Getenv("HTTPS_PROXY") - if proxy == "" { - return nil - } - } - proxyUrl, err := url.Parse(proxy) - if err != nil { - panic(err) +var Proxy func(*http.Request) (*url.URL, error) + +func init() { + proxyStr := stlval.Ternary(os.Getenv("https_proxy") != "", os.Getenv("https_proxy"), os.Getenv("HTTPS_PROXY")) + if proxyStr == "" { + return } - return http.ProxyURL(proxyUrl) -}() + proxy := stlerr.MustWith(url.Parse(proxyStr)) + Proxy = http.ProxyURL(proxy) +}