generated from songquanpeng/gin-template
-
-
Notifications
You must be signed in to change notification settings - Fork 4.5k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'main' of https://github.com/MotorBottle/one-api into main
- Loading branch information
Showing
40 changed files
with
865 additions
and
93 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,225 @@ | ||
package auth | ||
|
||
import ( | ||
"bytes" | ||
"encoding/json" | ||
"errors" | ||
"fmt" | ||
"github.com/gin-contrib/sessions" | ||
"github.com/gin-gonic/gin" | ||
"github.com/songquanpeng/one-api/common/config" | ||
"github.com/songquanpeng/one-api/common/logger" | ||
"github.com/songquanpeng/one-api/controller" | ||
"github.com/songquanpeng/one-api/model" | ||
"net/http" | ||
"strconv" | ||
"time" | ||
) | ||
|
||
type OidcResponse struct { | ||
AccessToken string `json:"access_token"` | ||
IDToken string `json:"id_token"` | ||
RefreshToken string `json:"refresh_token"` | ||
TokenType string `json:"token_type"` | ||
ExpiresIn int `json:"expires_in"` | ||
Scope string `json:"scope"` | ||
} | ||
|
||
type OidcUser struct { | ||
OpenID string `json:"sub"` | ||
Email string `json:"email"` | ||
Name string `json:"name"` | ||
PreferredUsername string `json:"preferred_username"` | ||
Picture string `json:"picture"` | ||
} | ||
|
||
func getOidcUserInfoByCode(code string) (*OidcUser, error) { | ||
if code == "" { | ||
return nil, errors.New("无效的参数") | ||
} | ||
values := map[string]string{ | ||
"client_id": config.OidcClientId, | ||
"client_secret": config.OidcClientSecret, | ||
"code": code, | ||
"grant_type": "authorization_code", | ||
"redirect_uri": fmt.Sprintf("%s/oauth/oidc", config.ServerAddress), | ||
} | ||
jsonData, err := json.Marshal(values) | ||
if err != nil { | ||
return nil, err | ||
} | ||
req, err := http.NewRequest("POST", config.OidcTokenEndpoint, bytes.NewBuffer(jsonData)) | ||
if err != nil { | ||
return nil, err | ||
} | ||
req.Header.Set("Content-Type", "application/json") | ||
req.Header.Set("Accept", "application/json") | ||
client := http.Client{ | ||
Timeout: 5 * time.Second, | ||
} | ||
res, err := client.Do(req) | ||
if err != nil { | ||
logger.SysLog(err.Error()) | ||
return nil, errors.New("无法连接至 OIDC 服务器,请稍后重试!") | ||
} | ||
defer res.Body.Close() | ||
var oidcResponse OidcResponse | ||
err = json.NewDecoder(res.Body).Decode(&oidcResponse) | ||
if err != nil { | ||
return nil, err | ||
} | ||
req, err = http.NewRequest("GET", config.OidcUserinfoEndpoint, nil) | ||
if err != nil { | ||
return nil, err | ||
} | ||
req.Header.Set("Authorization", "Bearer "+oidcResponse.AccessToken) | ||
res2, err := client.Do(req) | ||
if err != nil { | ||
logger.SysLog(err.Error()) | ||
return nil, errors.New("无法连接至 OIDC 服务器,请稍后重试!") | ||
} | ||
var oidcUser OidcUser | ||
err = json.NewDecoder(res2.Body).Decode(&oidcUser) | ||
if err != nil { | ||
return nil, err | ||
} | ||
return &oidcUser, nil | ||
} | ||
|
||
func OidcAuth(c *gin.Context) { | ||
session := sessions.Default(c) | ||
state := c.Query("state") | ||
if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) { | ||
c.JSON(http.StatusForbidden, gin.H{ | ||
"success": false, | ||
"message": "state is empty or not same", | ||
}) | ||
return | ||
} | ||
username := session.Get("username") | ||
if username != nil { | ||
OidcBind(c) | ||
return | ||
} | ||
if !config.OidcEnabled { | ||
c.JSON(http.StatusOK, gin.H{ | ||
"success": false, | ||
"message": "管理员未开启通过 OIDC 登录以及注册", | ||
}) | ||
return | ||
} | ||
code := c.Query("code") | ||
oidcUser, err := getOidcUserInfoByCode(code) | ||
if err != nil { | ||
c.JSON(http.StatusOK, gin.H{ | ||
"success": false, | ||
"message": err.Error(), | ||
}) | ||
return | ||
} | ||
user := model.User{ | ||
OidcId: oidcUser.OpenID, | ||
} | ||
if model.IsOidcIdAlreadyTaken(user.OidcId) { | ||
err := user.FillUserByOidcId() | ||
if err != nil { | ||
c.JSON(http.StatusOK, gin.H{ | ||
"success": false, | ||
"message": err.Error(), | ||
}) | ||
return | ||
} | ||
} else { | ||
if config.RegisterEnabled { | ||
user.Email = oidcUser.Email | ||
if oidcUser.PreferredUsername != "" { | ||
user.Username = oidcUser.PreferredUsername | ||
} else { | ||
user.Username = "oidc_" + strconv.Itoa(model.GetMaxUserId()+1) | ||
} | ||
if oidcUser.Name != "" { | ||
user.DisplayName = oidcUser.Name | ||
} else { | ||
user.DisplayName = "OIDC User" | ||
} | ||
err := user.Insert(0) | ||
if err != nil { | ||
c.JSON(http.StatusOK, gin.H{ | ||
"success": false, | ||
"message": err.Error(), | ||
}) | ||
return | ||
} | ||
} else { | ||
c.JSON(http.StatusOK, gin.H{ | ||
"success": false, | ||
"message": "管理员关闭了新用户注册", | ||
}) | ||
return | ||
} | ||
} | ||
|
||
if user.Status != model.UserStatusEnabled { | ||
c.JSON(http.StatusOK, gin.H{ | ||
"message": "用户已被封禁", | ||
"success": false, | ||
}) | ||
return | ||
} | ||
controller.SetupLogin(&user, c) | ||
} | ||
|
||
func OidcBind(c *gin.Context) { | ||
if !config.OidcEnabled { | ||
c.JSON(http.StatusOK, gin.H{ | ||
"success": false, | ||
"message": "管理员未开启通过 OIDC 登录以及注册", | ||
}) | ||
return | ||
} | ||
code := c.Query("code") | ||
oidcUser, err := getOidcUserInfoByCode(code) | ||
if err != nil { | ||
c.JSON(http.StatusOK, gin.H{ | ||
"success": false, | ||
"message": err.Error(), | ||
}) | ||
return | ||
} | ||
user := model.User{ | ||
OidcId: oidcUser.OpenID, | ||
} | ||
if model.IsOidcIdAlreadyTaken(user.OidcId) { | ||
c.JSON(http.StatusOK, gin.H{ | ||
"success": false, | ||
"message": "该 OIDC 账户已被绑定", | ||
}) | ||
return | ||
} | ||
session := sessions.Default(c) | ||
id := session.Get("id") | ||
// id := c.GetInt("id") // critical bug! | ||
user.Id = id.(int) | ||
err = user.FillUserById() | ||
if err != nil { | ||
c.JSON(http.StatusOK, gin.H{ | ||
"success": false, | ||
"message": err.Error(), | ||
}) | ||
return | ||
} | ||
user.OidcId = oidcUser.OpenID | ||
err = user.Update(false) | ||
if err != nil { | ||
c.JSON(http.StatusOK, gin.H{ | ||
"success": false, | ||
"message": err.Error(), | ||
}) | ||
return | ||
} | ||
c.JSON(http.StatusOK, gin.H{ | ||
"success": true, | ||
"message": "bind", | ||
}) | ||
return | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.