-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathoauth.go
117 lines (96 loc) · 3.16 KB
/
oauth.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
package ssso
import (
"context"
"crypto/rand"
"encoding/base64"
"fmt"
"time"
"github.com/google/uuid"
"github.com/pilab-dev/shadow-sso/api"
"github.com/rs/zerolog/log"
)
// GenerateAuthCode creates a single-use authorization code for the OAuth2
// authorization code flow and persists it via the OAuth repository.
//
// The code is 32 bytes from crypto/rand encoded with unpadded URL-safe
// base64, so it can be placed in a redirect URI query string without extra
// escaping. Standard base64 may emit '+', '/' and '=', and a '+' in a query
// string is decoded as a space, which would make the code fail lookup when
// the client redeems it.
//
// The stored record carries clientID, redirectURI and scope, and expires
// authCodeTTL after creation. It returns the code, or an error if random
// generation or persistence fails.
func (s *OAuthService) GenerateAuthCode(ctx context.Context, clientID, redirectURI, scope string) (string, error) {
	// Short lifetime limits how long an intercepted code remains usable.
	const authCodeTTL = 10 * time.Minute

	b := make([]byte, 32)
	if _, err := rand.Read(b); err != nil {
		return "", fmt.Errorf("failed to generate random bytes: %w", err)
	}
	code := base64.RawURLEncoding.EncodeToString(b)

	// One timestamp so CreatedAt and ExpiresAt are exactly authCodeTTL apart.
	now := time.Now()
	authCode := &AuthCode{
		Code:        code,
		ClientID:    clientID,
		RedirectURI: redirectURI,
		Scope:       scope,
		ExpiresAt:   now.Add(authCodeTTL),
		CreatedAt:   now,
	}
	if err := s.oauthRepo.SaveAuthCode(ctx, authCode); err != nil {
		return "", fmt.Errorf("failed to save auth code: %w", err)
	}
	return code, nil
}
// GenerateTokens exchanges an authorization code for an access token and a
// refresh token.
//
// The stored code is loaded by value and must belong to clientID, be unused
// and be unexpired. Otherwise an error is returned and the code is not
// marked. On success the code is marked as used and a Bearer token response
// is returned. Its ExpiresIn is derived from the same constant as the
// access token's ExpiresAt.
//
// NOTE(review): neither the access token nor the refresh token is persisted,
// because the StoreToken call is commented out and the Token record is
// discarded. Tokens issued here therefore cannot be looked up later through
// the OAuth repository.
//
// NOTE(review): the Used check and MarkAuthCodeAsUsed are separate calls.
// Concurrent redemptions of the same code may both pass the check unless
// the repository makes marking atomic. Confirm against the repo
// implementation.
func (s *OAuthService) GenerateTokens(ctx context.Context, code, clientID string) (*api.TokenResponse, error) {
	// Single source of truth for the access token lifetime. It is kept as an
	// untyped constant so it assigns to ExpiresIn whatever integer type that
	// field has.
	const accessTokenTTLSeconds = 3600
	accessTokenTTL := accessTokenTTLSeconds * time.Second

	authCode, err := s.oauthRepo.GetAuthCode(ctx, code)
	if err != nil {
		return nil, fmt.Errorf("invalid auth code: %w", err)
	}

	now := time.Now()
	switch {
	case authCode.ClientID != clientID:
		return nil, fmt.Errorf("client ID mismatch")
	case authCode.Used:
		return nil, fmt.Errorf("auth code already used")
	case now.After(authCode.ExpiresAt):
		return nil, fmt.Errorf("auth code expired")
	}

	accessToken := uuid.NewString()
	refreshToken := uuid.NewString()

	token := &Token{
		ID:         uuid.NewString(),
		TokenType:  "access_token",
		TokenValue: accessToken,
		ClientID:   clientID,
		Scope:      authCode.Scope,
		ExpiresAt:  now.Add(accessTokenTTL),
		CreatedAt:  now,
		LastUsedAt: now,
	}
	// TODO(review): persist the token; storage is currently disabled.
	_ = token
	// if err := s.oauthRepo.StoreToken(ctx, token); err != nil {
	// 	return nil, fmt.Errorf("failed to store token: %w", err)
	// }

	if err := s.oauthRepo.MarkAuthCodeAsUsed(ctx, code); err != nil {
		return nil, fmt.Errorf("failed to mark auth code as used: %w", err)
	}

	return &api.TokenResponse{
		AccessToken:  accessToken,
		TokenType:    "Bearer",
		ExpiresIn:    accessTokenTTLSeconds,
		RefreshToken: refreshToken,
	}, nil
}
// GetUserInfo returns user information for the subject of a valid access
// token.
//
// The access token is validated first through the token service. An invalid
// token yields an error wrapping the validation failure.
//
// The user lookup itself is not implemented yet. For a valid token this
// therefore logs and returns an explicit error, instead of a nil map with a
// nil error that callers would mistake for a successful, empty result.
func (s *OAuthService) GetUserInfo(ctx context.Context, token string) (map[string]interface{}, error) {
	userID, err := s.tokenService.ValidateAccessToken(ctx, token)
	if err != nil {
		return nil, fmt.Errorf("invalid access token: %w", err)
	}
	_ = userID // TODO: resolve userID to user info (see the commented call below).
	log.Error().Msg("GetUserInfo not implemented")
	// return s.userRepo.GetUserInfo(ctx, userID)
	return nil, fmt.Errorf("get user info: not implemented")
}
// RevokeToken revokes the given token by delegating to the token service's
// RevokeToken. Whatever error the token service reports is passed back to
// the caller as-is.
func (s *OAuthService) RevokeToken(ctx context.Context, token string) error {
	err := s.tokenService.RevokeToken(ctx, token)
	return err
}