Skip to content

Commit

Permalink
feat!: Allow passing client data through g8.data context key
Browse files Browse the repository at this point in the history
This adds the Data field to the Client struct
  • Loading branch information
TwiN committed Jan 12, 2025
1 parent cc6ae9d commit 9a75e85
Show file tree
Hide file tree
Showing 10 changed files with 137 additions and 58 deletions.
10 changes: 5 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
# g8

![test](https://github.com/TwiN/g8/workflows/test/badge.svg?branch=master)
[![Go Report Card](https://goreportcard.com/badge/github.com/TwiN/g8)](https://goreportcard.com/report/github.com/TwiN/g8/v2)
[![Go Report Card](https://goreportcard.com/badge/github.com/TwiN/g8)](https://goreportcard.com/report/github.com/TwiN/g8/v3)
[![codecov](https://codecov.io/gh/TwiN/g8/branch/master/graph/badge.svg)](https://codecov.io/gh/TwiN/g8)
[![Go version](https://img.shields.io/github/go-mod/go-version/TwiN/g8.svg)](https://github.com/TwiN/g8)
[![Go Reference](https://pkg.go.dev/badge/github.com/TwiN/g8.svg)](https://pkg.go.dev/github.com/TwiN/g8/v2)
[![Go Reference](https://pkg.go.dev/badge/github.com/TwiN/g8.svg)](https://pkg.go.dev/github.com/TwiN/g8/v3)
[![Follow TwiN](https://img.shields.io/github/followers/TwiN?label=Follow&style=social)](https://github.com/TwiN)

g8, pronounced gate, is a simple Go library for protecting HTTP handlers.
Expand All @@ -14,7 +14,7 @@ Tired of constantly re-implementing a security layer for each application? Me to

## Installation
```console
go get -u github.com/TwiN/g8/v2
go get -u github.com/TwiN/g8/v3
```


Expand Down Expand Up @@ -284,7 +284,7 @@ gate := g8.New().WithAuthorizationService(authorizationService).WithCustomTokenE
package main

import (
g8 "github.com/TwiN/g8/v2"
g8 "github.com/TwiN/g8/v3"
)

type customCache struct {
Expand All @@ -309,7 +309,7 @@ func main() {
// has the user's token as well as the permissions granted to said user
user := database.GetUserByToken(token)
if user != nil {
return g8.NewClient(user.Token).WithPermissions(user.Permissions)
return g8.NewClient(user.Token).WithPermissions(user.Permissions).WithData(user.Data)
}
return nil
}
Expand Down
17 changes: 10 additions & 7 deletions authorization.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,24 +104,27 @@ func (authorizationService *AuthorizationService) WithClientProvider(provider *C
return authorizationService
}

// IsAuthorized checks whether a client with a given token exists and has the permissions required.
// Authorize checks whether a client with a given token exists and has the permissions required.
//
// If permissionsRequired is nil or empty and a client with the given token exists, said client will have access to all
// handlers that are not protected by a given permission.
func (authorizationService *AuthorizationService) IsAuthorized(token string, permissionsRequired []string) bool {
//
// Returns whether the client is authorized, and if true, the client that was authorized.
func (authorizationService *AuthorizationService) Authorize(token string, permissionsRequired []string) (client *Client, authorized bool) {
if len(token) == 0 {
return false
return nil, false
}
authorizationService.mutex.RLock()
client, _ := authorizationService.clients[token]
client, _ = authorizationService.clients[token]
authorizationService.mutex.RUnlock()
// If there's no clients with the given token directly stored in the AuthorizationService, fall back to the
// client provider, if there's one configured.
if client == nil && authorizationService.clientProvider != nil {
client = authorizationService.clientProvider.GetClientByToken(token)
}
if client != nil {
return client.HasPermissions(permissionsRequired)
if client != nil && client.HasPermissions(permissionsRequired) {
// If the client has the required permissions, return true and the client
return client, true
}
return false
return nil, false
}
54 changes: 27 additions & 27 deletions authorization_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,105 +4,105 @@ import "testing"

func TestAuthorizationService_IsAuthorized(t *testing.T) {
authorizationService := NewAuthorizationService().WithToken("token")
if !authorizationService.IsAuthorized("token", nil) {
if _, authorized := authorizationService.Authorize("token", nil); !authorized {
t.Error("should've returned true")
}
if authorizationService.IsAuthorized("bad-token", nil) {
if _, authorized := authorizationService.Authorize("bad-token", nil); authorized {
t.Error("should've returned false")
}
if authorizationService.IsAuthorized("token", []string{"admin"}) {
if _, authorized := authorizationService.Authorize("token", []string{"admin"}); authorized {
t.Error("should've returned false")
}
if authorizationService.IsAuthorized("", nil) {
if _, authorized := authorizationService.Authorize("", nil); authorized {
t.Error("should've returned false")
}
}

func TestAuthorizationService_IsAuthorizedWithPermissions(t *testing.T) {
authorizationService := NewAuthorizationService().WithClient(NewClient("token").WithPermissions([]string{"a", "b"}))
if !authorizationService.IsAuthorized("token", nil) {
if _, authorized := authorizationService.Authorize("token", nil); !authorized {
t.Error("should've returned true")
}
if !authorizationService.IsAuthorized("token", []string{"a"}) {
if _, authorized := authorizationService.Authorize("token", []string{"a"}); !authorized {
t.Error("should've returned true")
}
if !authorizationService.IsAuthorized("token", []string{"b"}) {
if _, authorized := authorizationService.Authorize("token", []string{"b"}); !authorized {
t.Error("should've returned true")
}
if !authorizationService.IsAuthorized("token", []string{"a", "b"}) {
if _, authorized := authorizationService.Authorize("token", []string{"a", "b"}); !authorized {
t.Error("should've returned true")
}
if authorizationService.IsAuthorized("token", []string{"c"}) {
if _, authorized := authorizationService.Authorize("token", []string{"c"}); authorized {
t.Error("should've returned false")
}
if authorizationService.IsAuthorized("token", []string{"a", "c"}) {
if _, authorized := authorizationService.Authorize("token", []string{"a", "c"}); authorized {
t.Error("should've returned false")
}
if authorizationService.IsAuthorized("bad-token", nil) {
if _, authorized := authorizationService.Authorize("bad-token", nil); authorized {
t.Error("should've returned false")
}
if authorizationService.IsAuthorized("bad-token", []string{"a"}) {
if _, authorized := authorizationService.Authorize("bad-token", []string{"a"}); authorized {
t.Error("should've returned false")
}
if authorizationService.IsAuthorized("", []string{"a"}) {
if _, authorized := authorizationService.Authorize("", []string{"a"}); authorized {
t.Error("should've returned false")
}
}

func TestAuthorizationService_WithToken(t *testing.T) {
authorizationService := NewAuthorizationService().WithToken("token")
if !authorizationService.IsAuthorized("token", nil) {
if _, authorized := authorizationService.Authorize("token", nil); !authorized {
t.Error("should've returned true")
}
if authorizationService.IsAuthorized("bad-token", nil) {
if _, authorized := authorizationService.Authorize("bad-token", nil); authorized {
t.Error("should've returned false")
}
if authorizationService.IsAuthorized("token", []string{"admin"}) {
if _, authorized := authorizationService.Authorize("token", []string{"admin"}); authorized {
t.Error("should've returned false")
}
}

func TestAuthorizationService_WithTokens(t *testing.T) {
authorizationService := NewAuthorizationService().WithTokens([]string{"1", "2"})
if !authorizationService.IsAuthorized("1", nil) {
if _, authorized := authorizationService.Authorize("1", nil); !authorized {
t.Error("should've returned true")
}
if !authorizationService.IsAuthorized("2", nil) {
if _, authorized := authorizationService.Authorize("2", nil); !authorized {
t.Error("should've returned true")
}
if authorizationService.IsAuthorized("3", nil) {
if _, authorized := authorizationService.Authorize("3", nil); authorized {
t.Error("should've returned false")
}
}

func TestAuthorizationService_WithClient(t *testing.T) {
authorizationService := NewAuthorizationService().WithClient(NewClient("token").WithPermissions([]string{"a", "b"}))
if !authorizationService.IsAuthorized("token", []string{"a", "b"}) {
if _, authorized := authorizationService.Authorize("token", []string{"a", "b"}); !authorized {
t.Error("should've returned true")
}
if !authorizationService.IsAuthorized("token", []string{"a"}) {
if _, authorized := authorizationService.Authorize("token", []string{"a"}); !authorized {
t.Error("should've returned true")
}
if !authorizationService.IsAuthorized("token", []string{"b"}) {
if _, authorized := authorizationService.Authorize("token", []string{"b"}); !authorized {
t.Error("should've returned true")
}
if authorizationService.IsAuthorized("token", []string{"c"}) {
if _, authorized := authorizationService.Authorize("token", []string{"c"}); authorized {
t.Error("should've returned false")
}
}

func TestAuthorizationService_WithClients(t *testing.T) {
authorizationService := NewAuthorizationService().WithClients([]*Client{NewClient("1").WithPermission("a"), NewClient("2").WithPermission("b")})
if !authorizationService.IsAuthorized("1", []string{"a"}) {
if _, authorized := authorizationService.Authorize("1", []string{"a"}); !authorized {
t.Error("should've returned true")
}
if !authorizationService.IsAuthorized("2", []string{"b"}) {
if _, authorized := authorizationService.Authorize("2", []string{"b"}); !authorized {
t.Error("should've returned true")
}
if authorizationService.IsAuthorized("1", []string{"b"}) {
if _, authorized := authorizationService.Authorize("1", []string{"b"}); authorized {
t.Error("should've returned false")
}
if authorizationService.IsAuthorized("2", []string{"a"}) {
if _, authorized := authorizationService.Authorize("2", []string{"a"}); authorized {
t.Error("should've returned false")
}
}
25 changes: 23 additions & 2 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@ type Client struct {
// If you only wish to use Gate.Protect and Gate.ProtectFunc, you do not have to worry about this,
// since they're only used by Gate.ProtectWithPermissions and Gate.ProtectFuncWithPermissions
Permissions []string

// Data is a field that can be used to store any data you want to associate with the client.
Data any
}

// NewClient creates a Client with a given token
Expand All @@ -25,6 +28,18 @@ func NewClientWithPermissions(token string, permissions []string) *Client {
return NewClient(token).WithPermissions(permissions)
}

// NewClientWithData creates a Client with some data
// Equivalent to using NewClient and WithData
func NewClientWithData(token string, data any) *Client {
return NewClient(token).WithData(data)
}

// NewClientWithPermissionsAndData creates a Client with a slice of permissions and some data
// Equivalent to using NewClient, WithPermissions and WithData
func NewClientWithPermissionsAndData(token string, permissions []string, data any) *Client {
return NewClient(token).WithPermissions(permissions).WithData(data)
}

// WithPermissions adds a slice of permissions to a client
func (client *Client) WithPermissions(permissions []string) *Client {
client.Permissions = append(client.Permissions, permissions...)
Expand All @@ -37,8 +52,14 @@ func (client *Client) WithPermission(permission string) *Client {
return client
}

// WithData attaches data to a client
func (client *Client) WithData(data any) *Client {
client.Data = data
return client
}

// HasPermission checks whether a client has a given permission
func (client Client) HasPermission(permissionRequired string) bool {
func (client *Client) HasPermission(permissionRequired string) bool {
for _, permission := range client.Permissions {
if permissionRequired == permission {
return true
Expand All @@ -48,7 +69,7 @@ func (client Client) HasPermission(permissionRequired string) bool {
}

// HasPermissions checks whether a client has the all permissions passed
func (client Client) HasPermissions(permissionsRequired []string) bool {
func (client *Client) HasPermissions(permissionsRequired []string) bool {
for _, permissionRequired := range permissionsRequired {
if !client.HasPermission(permissionRequired) {
return false
Expand Down
38 changes: 38 additions & 0 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,41 @@ func TestClient_HasPermissions(t *testing.T) {
t.Errorf("client has permissions %s, therefore HasPermissions([a, b, c]) should've been false", client.Permissions)
}
}

func TestClient_WithData(t *testing.T) {
client := NewClient("token")
if client.Data != nil {
t.Error("expected client data to be nil")
}
client.WithData(5)
if client.Data != 5 {
t.Errorf("expected client data to be 5, got %d", client.Data)
}
client.WithData(map[string]string{"key": "value"})
if data, ok := client.Data.(map[string]string); !ok || data["key"] != "value" {
t.Errorf("expected client data to be map[string]string{key: value}, got %v", client.Data)
}
}

func TestNewClientWithData(t *testing.T) {
client := NewClientWithData("token", 5)
if client.Data != 5 {
t.Errorf("expected client data to be 5, got %d", client.Data)
}
}

func TestNewClientWithPermissionsAndData(t *testing.T) {
client := NewClientWithPermissionsAndData("token", []string{"a", "b"}, 5)
if client.Data != 5 {
t.Errorf("expected client data to be 5, got %d", client.Data)
}
if !client.HasPermission("a") {
t.Errorf("client has permissions %s, therefore HasPermission(a) should've been true", client.Permissions)
}
if !client.HasPermission("b") {
t.Errorf("client has permissions %s, therefore HasPermission(b) should've been true", client.Permissions)
}
if client.HasPermission("c") {
t.Errorf("client has permissions %s, therefore HasPermission(c) should've been false", client.Permissions)
}
}
4 changes: 3 additions & 1 deletion clientprovider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import (
var (
getClientByTokenFunc = func(token string) *Client {
if token == "valid-token" {
return &Client{Token: token}
return NewClient("valid-token").WithData("client-data")
}
return nil
}
Expand All @@ -21,6 +21,8 @@ func TestClientProvider_GetClientByToken(t *testing.T) {
provider := NewClientProvider(getClientByTokenFunc)
if client := provider.GetClientByToken("valid-token"); client == nil {
t.Error("should've returned a client")
} else if client.Data != "client-data" {
t.Error("expected client data to be 'client-data', got", client.Data)
}
if client := provider.GetClientByToken("invalid-token"); client != nil {
t.Error("should've returned nil")
Expand Down
13 changes: 10 additions & 3 deletions gate.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,11 @@ const (
// DefaultTooManyRequestsResponseBody is the default response body returned if a request exceeded the allowed rate limit
DefaultTooManyRequestsResponseBody = "too many requests"

// TokenContextKey is the key used to store the token in the context.
// TokenContextKey is the key used to store the client's token in the context.
TokenContextKey = "g8.token"

// DataContextKey is the key used to store the client's data in the context.
DataContextKey = "g8.data"
)

// Gate is lock to the front door of your API, letting only those you allow through.
Expand Down Expand Up @@ -187,12 +190,16 @@ func (gate *Gate) ProtectFuncWithPermissions(handlerFunc http.HandlerFunc, permi
}
if gate.authorizationService != nil {
token := gate.ExtractTokenFromRequest(request)
if !gate.authorizationService.IsAuthorized(token, permissions) {
if client, authorized := gate.authorizationService.Authorize(token, permissions); !authorized {
writer.WriteHeader(http.StatusUnauthorized)
_, _ = writer.Write(gate.unauthorizedResponseBody)
return
} else {
request = request.WithContext(context.WithValue(request.Context(), TokenContextKey, token))
if client != nil && client.Data != nil {
request = request.WithContext(context.WithValue(request.Context(), DataContextKey, client.Data))
}
}
request = request.WithContext(context.WithValue(request.Context(), TokenContextKey, token))
}
handlerFunc(writer, request)
}
Expand Down
2 changes: 1 addition & 1 deletion gate_bench_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ func BenchmarkGate_ProtectWithClientProviderConcurrently(b *testing.B) {
gate := New().WithAuthorizationService(NewAuthorizationService().WithClientProvider(mockClientProvider))

request, _ := http.NewRequest("GET", "/handle", http.NoBody)
request.Header.Add("Authorization", fmt.Sprintf("Bearer %s", TestProviderToken))
request.Header.Add("Authorization", fmt.Sprintf("Bearer %s", TestProviderClientToken))

firstBadRequest, _ := http.NewRequest("GET", "/handle", http.NoBody)
firstBadRequest.Header.Set("Authorization", fmt.Sprintf("Bearer %s", "bad-token-1"))
Expand Down
Loading

0 comments on commit 9a75e85

Please sign in to comment.