Skip to content

Commit

Permalink
支持登录 (#7)
Browse files Browse the repository at this point in the history
  • Loading branch information
kkkunny authored Nov 9, 2024
1 parent f046823 commit 64de7b2
Show file tree
Hide file tree
Showing 12 changed files with 418 additions and 47 deletions.
1 change: 1 addition & 0 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ RUN apk --no-cache add tzdata && \
ENV TZ Asia/Shanghai
RUN apk add --no-cache ca-certificates && update-ca-certificates
WORKDIR /app
RUN mkdir config
COPY --from=builder /app/bin/* /app
EXPOSE 80
ENTRYPOINT ["/app/server"]
18 changes: 13 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,15 @@

## 使用说明

1. **访问 HuggingChat 官网**
进入 [HuggingChat官网](https://huggingface.co/chat) 并登录。
1. **获取 Authorization**
以下两种方式任选其一!

2. **获取 Authorization**
从浏览器的 Cookie 中取出 `hf-chat` 的值,作为 Authorization。
+ 使用base64编码`username={你的账号名}&password={你的密码}`,假设你的账号名为usr,密码为pwd,则`username=usr&password=pwd`进行base64编码后结果为`dXNlcm5hbWU9dXNyJnBhc3N3b3JkPXB3ZA==`
+ 进入 [HuggingChat官网](https://huggingface.co/chat) 并登录,从浏览器的 Cookie 中取出 `hf-chat` 的值,形如`cc43f26e-142b-409f-b228-68316s5x30a9`

3. **会话创建注意事项**
将上面任意一种方式获得的值填入Authorization

2. **!!注意事项!!会话创建**
由于调用创建会话的接口创建出的会话总是呈不可用状态,因此需要提前为每个模型创建好会话。

## 调用说明
Expand Down Expand Up @@ -98,6 +100,12 @@ curl -X POST "http://localhost:5695/v1/chat/completions" \
docker run -d --name HuggingChat -p 5695:80 kkkunny/hugging-chat-api:latest
```

强烈建议将/app/config映射到本地路径

```bash
docker run -d --name HuggingChat -p 5695:80 -v YOUR_PATH:/app/config kkkunny/hugging-chat-api:latest
```

### 使用 Koyeb 一键部署

1. **准备工作**
Expand Down
20 changes: 15 additions & 5 deletions handler/chat_completions.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,24 +17,34 @@ import (

"github.com/kkkunny/HuggingChatAPI/internal/api"
"github.com/kkkunny/HuggingChatAPI/internal/config"
"github.com/kkkunny/HuggingChatAPI/internal/consts"
)

func ChatCompletions(w http.ResponseWriter, r *http.Request) {
token := strings.TrimPrefix(r.Header.Get("Authorization"), "Bearer ")
cli := api.NewAPI(consts.HuggingChatDomain, token, config.Proxy)
cli, err := api.NewAPI(config.HuggingChatDomain, token)
if err != nil {
config.Logger.Error(err)
http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
return
}
err = cli.RefreshCookie(r.Context())
if err != nil {
config.Logger.Error(err)
http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
return
}

var req openai.ChatCompletionRequest
body, err := io.ReadAll(r.Body)
if err != nil {
config.Logger.Error(err)
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
return
}
err = json.Unmarshal(body, &req)
if err != nil {
config.Logger.Error(err)
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
return
}

Expand Down Expand Up @@ -128,7 +138,7 @@ func chatCompletionsNoStream(w http.ResponseWriter, msgID string, convInfo *api.
Type: openai.ChatMessagePartTypeImageURL,
ImageURL: &openai.ChatMessageImageURL{
Detail: openai.ImageURLDetailAuto,
URL: fmt.Sprintf("%s/chat/conversation/%s/output/%s", consts.HuggingChatDomain, convInfo.ConversationID, *msg.SHA),
URL: fmt.Sprintf("%s/chat/conversation/%s/output/%s", config.HuggingChatDomain, convInfo.ConversationID, *msg.SHA),
},
})
}
Expand Down
15 changes: 13 additions & 2 deletions handler/list_models.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,23 @@ import (

"github.com/kkkunny/HuggingChatAPI/internal/api"
"github.com/kkkunny/HuggingChatAPI/internal/config"
"github.com/kkkunny/HuggingChatAPI/internal/consts"
)

func ListModels(w http.ResponseWriter, r *http.Request) {
token := strings.TrimPrefix(r.Header.Get("Authorization"), "Bearer ")
cli := api.NewAPI(consts.HuggingChatDomain, token, config.Proxy)
cli, err := api.NewAPI(config.HuggingChatDomain, token)
if err != nil {
config.Logger.Error(err)
http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
return
}
err = cli.RefreshCookie(r.Context())
if err != nil {
config.Logger.Error(err)
http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
return
}

models, err := cli.ListModels(r.Context())
if err != nil {
config.Logger.Error(err)
Expand Down
96 changes: 75 additions & 21 deletions internal/api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@ package api
import (
"bufio"
"context"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"regexp"
"strings"
"time"
Expand All @@ -18,18 +18,77 @@ import (
)

type Api struct {
domain string
client *req.Client
domain string
client *req.Client
cookieMgr cookieMgr
}

func NewAPI(domain string, token string, proxy func(*http.Request) (*url.URL, error)) *Api {
return &Api{
func NewAPI(domain string, token string) (*Api, error) {
api := &Api{
domain: domain,
client: req.C().
SetProxy(proxy).
SetCommonCookies(&http.Cookie{Name: "hf-chat", Value: token}).
SetUserAgent("Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/124.0.0.0 Safari/537.36 Edg/124.0.0.0"),
client: globalClient.Clone().
SetCommonHeader("origin", domain),
}
err := api.SetToken(token)
return api, err
}

func (api *Api) SetToken(token string) error {
account, err := base64.StdEncoding.DecodeString(token)
if err == nil {
res := regexp.MustCompile(`username=(.+?)&password=(.+)`).FindStringSubmatch(string(account))
if len(res) != 3 {
return errors.New("invalid token")
}
api.cookieMgr = newAccountCookieMgr(res[1], res[2])
return nil
}
api.cookieMgr = newTokenCookieMgr(token)
return nil
}

func (api *Api) RefreshCookie(ctx context.Context) error {
cookies, err := api.cookieMgr.Cookies(ctx)
if err != nil {
return err
}
api.client.ClearCookies().SetCommonCookies(cookies...)

isLogin, err := api.CheckLogin(ctx)
if err != nil {
return err
} else if isLogin {
return nil
}

cookies, err = api.cookieMgr.Refresh(ctx)
if err != nil {
return err
}
api.client.ClearCookies().SetCommonCookies(cookies...)

isLogin, err = api.CheckLogin(ctx)
if err != nil {
return err
} else if isLogin {
return nil
}
return errors.New("login failed")
}

func (api *Api) CheckLogin(ctx context.Context) (bool, error) {
httpResp, err := api.client.R().
SetContext(ctx).
SetHeaders(map[string]string{
"accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.7",
}).
Get(fmt.Sprintf("%s/chat/", api.domain))
if err != nil {
return false, err
} else if httpResp.GetStatusCode() != http.StatusOK {
return false, fmt.Errorf("http error: code=%d, status=%s", httpResp.GetStatusCode(), httpResp.GetStatus())
}
return !strings.Contains(httpResp.String(), "action=\"/chat/login\""), nil
}

type ModelInfo struct {
Expand All @@ -41,11 +100,10 @@ type ModelInfo struct {
}

func (api *Api) ListModels(ctx context.Context) (resp []*ModelInfo, err error) {
urlStr := fmt.Sprintf("%s/chat/models/__data.json?x-sveltekit-invalidated=10", api.domain)
httpResp, err := api.client.R().
SetContext(ctx).
SetSuccessResult(make(map[string]any)).
Get(urlStr)
Get(fmt.Sprintf("%s/chat/models/__data.json?x-sveltekit-invalidated=10", api.domain))
if err != nil {
return nil, err
} else if httpResp.GetStatusCode() != http.StatusOK {
Expand Down Expand Up @@ -107,12 +165,11 @@ type CreateConversationResponse struct {
}

func (api *Api) CreateConversation(ctx context.Context, req *CreateConversationRequest) (*CreateConversationResponse, error) {
urlStr := fmt.Sprintf("%s/chat/conversation", api.domain)
httpResp, err := api.client.R().
SetContext(ctx).
SetBodyJsonMarshal(req).
SetSuccessResult(CreateConversationResponse{}).
Post(urlStr)
Post(fmt.Sprintf("%s/chat/conversation", api.domain))
if err != nil {
return nil, err
} else if httpResp.GetStatusCode() != http.StatusOK {
Expand All @@ -127,10 +184,9 @@ type DeleteConversationRequest struct {
}

func (api *Api) DeleteConversation(ctx context.Context, req *DeleteConversationRequest) error {
urlStr := fmt.Sprintf("%s/chat/conversation/%s", api.domain, req.ConversationID)
httpResp, err := api.client.R().
SetContext(ctx).
Delete(urlStr)
Delete(fmt.Sprintf("%s/chat/conversation/%s", api.domain, req.ConversationID))
if err != nil {
return err
} else if httpResp.GetStatusCode() != http.StatusOK {
Expand All @@ -147,10 +203,9 @@ type SimpleConversationInfo struct {
}

func (api *Api) ListConversations(ctx context.Context) (resp []*SimpleConversationInfo, err error) {
urlStr := fmt.Sprintf("%s/chat/models/__data.json?x-sveltekit-invalidated=10", api.domain)
httpResp, err := api.client.R().
SetContext(ctx).
Get(urlStr)
Get(fmt.Sprintf("%s/chat/models/__data.json?x-sveltekit-invalidated=10", api.domain))
if err != nil {
return nil, err
} else if httpResp.GetStatusCode() != http.StatusOK {
Expand Down Expand Up @@ -225,11 +280,10 @@ type Message struct {
}

func (api *Api) ConversationInfo(ctx context.Context, req *ConversationInfoRequest) (resp *ConversationInfoResponse, err error) {
urlStr := fmt.Sprintf("%s/chat/conversation/%s/__data.json?x-sveltekit-invalidated=01", api.domain, req.ConversationID)
httpResp, err := api.client.R().
SetContext(ctx).
SetSuccessResult(make(map[string]any)).
Get(urlStr)
Get(fmt.Sprintf("%s/chat/conversation/%s/__data.json?x-sveltekit-invalidated=01", api.domain, req.ConversationID))
if err != nil {
return nil, err
} else if httpResp.GetStatusCode() != http.StatusOK {
Expand Down Expand Up @@ -374,7 +428,7 @@ func (api *Api) ChatConversation(ctx context.Context, req *ChatConversationReque
if err != nil {
return nil, err
}
urlStr := fmt.Sprintf("%s/chat/conversation/%s", api.domain, req.ConversationID)

resp, err := api.client.R().
SetContext(ctx).
SetHeaders(map[string]string{
Expand All @@ -391,7 +445,7 @@ func (api *Api) ChatConversation(ctx context.Context, req *ChatConversationReque
}).
SetFormData(map[string]string{"data": string(reqBody)}).
DisableAutoReadResponse().
Post(urlStr)
Post(fmt.Sprintf("%s/chat/conversation/%s", api.domain, req.ConversationID))
if err != nil {
return nil, err
} else if resp.GetStatusCode() != http.StatusOK {
Expand Down
66 changes: 66 additions & 0 deletions internal/api/cookie_cache.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
package api

import (
"encoding/json"
"net/http"
"os"
"path/filepath"

stlerr "github.com/kkkunny/stl/error"

"github.com/kkkunny/HuggingChatAPI/internal/config"
)

var globalCookieCache *cookieCache

func init() {
globalCookieCache = newCookieCache()
stlerr.Must(globalCookieCache.Load())
}

type cookieCache struct {
data map[string][]*http.Cookie
}

func newCookieCache() *cookieCache {
return &cookieCache{data: make(map[string][]*http.Cookie)}
}

func (cache *cookieCache) Load() error {
data, err := os.ReadFile(config.CookieCachePath)
if err != nil && os.IsNotExist(err) {
return nil
} else if err != nil {
return err
}
err = json.Unmarshal(data, &cache.data)
return err
}

func (cache *cookieCache) Save() error {
err := os.MkdirAll(filepath.Dir(config.CookieCachePath), 0750)
if err != nil {
return err
}
file, err := os.OpenFile(config.CookieCachePath, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0666)
if err != nil {
return err
}
defer file.Close()

data, err := json.MarshalIndent(cache.data, "", " ")
if err != nil {
return err
}
_, err = file.Write(data)
return err
}

func (cache *cookieCache) Get(usr string) []*http.Cookie {
return cache.data[usr]
}

func (cache *cookieCache) Set(usr string, cookies []*http.Cookie) error {
cache.data[usr] = cookies
return cache.Save()
}
Loading

0 comments on commit 64de7b2

Please sign in to comment.