Skip to content

Commit

Permalink
api: Configurable permissions (#1314)
Browse files Browse the repository at this point in the history
* API Permissions

* dont make perms pointer

* add pre- and post-hooks

* add the first real migration

* fix migrator picking up non-instance kvs

* add func for activation

* fix api tests

* load user for token auth

* add storage client tests

* add separate call for migrations

* fix api

* make session duration configurable
  • Loading branch information
BeryJu authored Nov 24, 2024
1 parent 940d06e commit b27935a
Show file tree
Hide file tree
Showing 24 changed files with 616 additions and 67 deletions.
17 changes: 11 additions & 6 deletions pkg/instance/instance.go
Original file line number Diff line number Diff line change
Expand Up @@ -336,16 +336,21 @@ func (i *Instance) startRole(ctx context.Context, id string, rawConfig []byte) b
defer srs.Finish()
defer i.putInstanceInfo(srs.Context())
instanceRoleStarted.WithLabelValues(id).SetToCurrentTime()
// Run migrations
client, err := i.roles[id].RoleInstance.Migrator().Run(srs.Context())
if err != nil {
i.log.Warn("failed to run migrations for role", zap.String("roleId", id))
return false
client := i.roles[id].RoleInstance.kv
if mr, ok := i.roles[id].Role.(roles.MigratableRole); ok {
mr.RegisterMigrations()
// Run migrations
_client, err := i.roles[id].RoleInstance.Migrator().Run(srs.Context())
if err != nil {
i.log.Warn("failed to run migrations for role", zap.String("roleId", id))
return false
}
client = _client
}
// Overwrite role's KV client with the potentially hooked client for migrations
i.roles[id].RoleInstance.kv = client
// Start role
err = i.roles[id].Role.Start(srs.Context(), rawConfig)
err := i.roles[id].Role.Start(srs.Context(), rawConfig)
if err == roles.ErrRoleNotConfigured {
i.log.Info("role not configured", zap.String("roleId", id))
} else if err != nil {
Expand Down
4 changes: 4 additions & 0 deletions pkg/instance/migrate/inline_migration.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ func MustParseConstraint(input string) *semver.Constraints {
type InlineMigration struct {
MigrationName string
ActivateOnVersion *semver.Constraints
ActivateFunc func(*semver.Version) bool
HookFunc func(context.Context) (*storage.Client, error)
CleanupFunc func(context.Context) error
}
Expand All @@ -27,6 +28,9 @@ func (im *InlineMigration) Name() string {
}

func (im *InlineMigration) Check(clusterVersion *semver.Version, ctx context.Context) (bool, error) {
if im.ActivateFunc != nil {
return im.ActivateFunc(clusterVersion), nil
}
check := im.ActivateOnVersion.Check(clusterVersion)
return check, nil
}
Expand Down
17 changes: 11 additions & 6 deletions pkg/instance/migrate/migrate.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"encoding/json"
"sort"
"strings"

"beryju.io/gravity/pkg/extconfig"
"beryju.io/gravity/pkg/instance/types"
Expand All @@ -28,15 +29,16 @@ func New(ri roles.Instance) *Migrator {
}
}

func (mi *Migrator) GetClusterVersion() (*semver.Version, error) {
func (mi *Migrator) GetClusterVersion(ctx context.Context) (*semver.Version, error) {
type partialInstanceInfo struct {
Version string `json:"version" required:"true"`
}
pfx := mi.ri.KV().Key(
types.KeyInstance,
).Prefix(true).String()
instances, err := mi.ri.KV().Get(
context.Background(),
mi.ri.KV().Key(
types.KeyInstance,
).Prefix(true).String(),
ctx,
pfx,
clientv3.WithPrefix(),
)
if err != nil {
Expand All @@ -45,6 +47,9 @@ func (mi *Migrator) GetClusterVersion() (*semver.Version, error) {
// Gather all instances in the cluster and parse their versions
version := []*semver.Version{}
for _, inst := range instances.Kvs {
if strings.Count(strings.TrimPrefix(string(inst.Key), pfx), "/") > 0 {
continue
}
pi := partialInstanceInfo{}
err = json.Unmarshal(inst.Value, &pi)
if err != nil {
Expand All @@ -66,7 +71,7 @@ func (mi *Migrator) GetClusterVersion() (*semver.Version, error) {
}

func (mi *Migrator) Run(ctx context.Context) (*storage.Client, error) {
cv, err := mi.GetClusterVersion()
cv, err := mi.GetClusterVersion(ctx)
if err != nil {
return nil, err
}
Expand Down
4 changes: 2 additions & 2 deletions pkg/instance/migrate/migrate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,11 +97,11 @@ func TestMigrate_Hook(t *testing.T) {
ActivateOnVersion: migrate.MustParseConstraint("> 0.0.0"),
HookFunc: func(ctx context.Context) (*storage.Client, error) {
return ri.KV().WithHooks(storage.StorageHook{
Get: func(ctx context.Context, key string, opts ...clientv3.OpOption) error {
GetPre: func(ctx context.Context, key string, opts ...clientv3.OpOption) error {
ct += 1
return nil
},
Put: func(ctx context.Context, key, val string, opts ...clientv3.OpOption) error {
PutPre: func(ctx context.Context, key, val string, opts ...clientv3.OpOption) error {
ct += 1
return nil
},
Expand Down
45 changes: 36 additions & 9 deletions pkg/roles/api/auth/api_users.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ type APIUsersGetInput struct {
Username string `query:"username" description:"Optional username of a user to get"`
}
type APIUser struct {
Username string `json:"username" required:"true"`
Username string `json:"username" required:"true"`
Permissions []Permission `json:"permissions" required:"true"`
}
type APIUsersGetOutput struct {
Users []APIUser `json:"users" required:"true"`
Expand Down Expand Up @@ -47,7 +48,8 @@ func (ap *AuthProvider) APIUsersGet() usecase.Interactor {
continue
}
output.Users = append(output.Users, APIUser{
Username: u.Username,
Username: u.Username,
Permissions: u.Permissions,
})
}
return nil
Expand All @@ -62,19 +64,44 @@ func (ap *AuthProvider) APIUsersGet() usecase.Interactor {
type APIUsersPutInput struct {
Username string `query:"username" required:"true"`

Password string `json:"password" required:"true"`
Password string `json:"password" required:"true"`
Permissions []Permission `json:"permissions" required:"true"`
}

func (ap *AuthProvider) APIUsersPut() usecase.Interactor {
u := usecase.NewInteractor(func(ctx context.Context, input APIUsersPutInput, output *struct{}) error {
hash, err := bcrypt.GenerateFromPassword([]byte(input.Password), bcrypt.DefaultCost)
if err != nil {
return status.Wrap(err, status.Internal)
rawUsers, err := ap.inst.KV().Get(
ctx,
ap.inst.KV().Key(
types.KeyRole,
types.KeyUsers,
input.Username,
).String(),
)
var oldUser *User
if err == nil && len(rawUsers.Kvs) < 1 {
user, err := ap.userFromKV(rawUsers.Kvs[0])
if err != nil {
_ = bcrypt.CompareHashAndPassword([]byte{}, []byte(input.Password))
ap.log.Warn("failed to parse user", zap.Error(err), zap.String("user", input.Username))
return status.Wrap(err, status.Internal)
}
oldUser = user
}

user := &User{
Username: input.Username,
Password: string(hash),
ap: ap,
Username: input.Username,
Permissions: input.Permissions,
ap: ap,
}
if input.Password != "" {
hash, err := bcrypt.GenerateFromPassword([]byte(input.Password), bcrypt.DefaultCost)
if err != nil {
return status.Wrap(err, status.Internal)
}
user.Password = string(hash)
} else {
user.Password = oldUser.Password
}
err = user.put(ctx)
if err != nil {
Expand Down
25 changes: 22 additions & 3 deletions pkg/roles/api/auth/method_api_key.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ func (ap *AuthProvider) checkToken(r *http.Request) bool {
if !strings.EqualFold(parts[0], BearerType) {
return false
}
// Get token
rawTokens, err := ap.inst.KV().Get(
r.Context(),
ap.inst.KV().Key(
Expand All @@ -40,10 +41,28 @@ func (ap *AuthProvider) checkToken(r *http.Request) bool {
if err != nil {
return false
}
session := r.Context().Value(types.RequestSession).(*sessions.Session)
session.Values[types.SessionKeyUser] = User{
Username: key.Username,
// Get token's user
rawUsers, err := ap.inst.KV().Get(
r.Context(),
ap.inst.KV().Key(
types.KeyRole,
types.KeyUsers,
key.Username,
).String(),
)
if err != nil {
ap.log.Warn("failed to check token", zap.Error(err))
return false
}
if len(rawUsers.Kvs) < 1 {
return false
}
user, err := ap.userFromKV(rawUsers.Kvs[0])
if err != nil {
return false
}
session := r.Context().Value(types.RequestSession).(*sessions.Session)
session.Values[types.SessionKeyUser] = *user
session.Values[types.SessionKeyDirty] = true
return false
}
24 changes: 11 additions & 13 deletions pkg/roles/api/auth/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,22 +21,20 @@ func (ap *AuthProvider) isRequestAllowed(r *http.Request) bool {
if ap.isAllowedPath(r) {
return true
}
if ap.checkToken(r) {
return true
}
ap.checkToken(r)
session := r.Context().Value(types.RequestSession).(*sessions.Session)
u, ok := session.Values[types.SessionKeyUser]
if u != nil && ok {
hub := sentry.GetHubFromContext(r.Context())
if hub == nil {
hub = sentry.CurrentHub()
}
hub.Scope().SetUser(sentry.User{
Username: u.(User).Username,
})
return true
if u == nil || !ok {
return false
}
return false
hub := sentry.GetHubFromContext(r.Context())
if hub == nil {
hub = sentry.CurrentHub()
}
hub.Scope().SetUser(sentry.User{
Username: u.(User).Username,
})
return ap.checkPermission(r, u.(User))
}

func (ap *AuthProvider) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
Expand Down
32 changes: 32 additions & 0 deletions pkg/roles/api/auth/permission.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
package auth

import (
"net/http"
"strings"
)

const wildcard = "*"

func (ap *AuthProvider) checkPermission(req *http.Request, u User) bool {
var longestMatch *Permission
for _, perm := range u.Permissions {
if strings.HasSuffix(perm.Path, wildcard) && strings.HasPrefix(req.URL.Path, strings.TrimSuffix(perm.Path, wildcard)) {
if longestMatch == nil || len(perm.Path) > len(longestMatch.Path) {
longestMatch = &perm
}
} else if perm.Path == req.URL.Path {
if longestMatch == nil || len(perm.Path) > len(longestMatch.Path) {
longestMatch = &perm
}
}
}
if longestMatch == nil {
return false
}
for _, meth := range longestMatch.Methods {
if strings.EqualFold(meth, req.Method) {
return true
}
}
return false
}
48 changes: 48 additions & 0 deletions pkg/roles/api/auth/permissions_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
package auth

import (
"net/http"
"testing"

"github.com/stretchr/testify/assert"
)

func mustRequest(meth string, url string) *http.Request {
req, err := http.NewRequest(meth, url, nil)
if err != nil {
panic(err)
}
return req
}

func TestPermission_Fixed(t *testing.T) {
ap := AuthProvider{}
assert.True(t, ap.checkPermission(mustRequest("get", "/foo/bar"), User{
Permissions: []Permission{
{
Path: "/foo/bar",
Methods: []string{"get", "post"},
},
{
Path: "/foo/ba",
Methods: []string{"post"},
},
{
Path: "/foo",
Methods: []string{"head"},
},
},
}))
}

func TestPermission_Wildcard(t *testing.T) {
ap := AuthProvider{}
assert.True(t, ap.checkPermission(mustRequest("get", "/foo/bar"), User{
Permissions: []Permission{
{
Path: "/foo/*",
Methods: []string{"get"},
},
},
}))
}
6 changes: 6 additions & 0 deletions pkg/roles/api/auth/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,12 @@ func (ap *AuthProvider) CreateUser(ctx context.Context, username, password strin

user := User{
Password: string(hashedPw),
Permissions: []Permission{
{
Path: "/*",
Methods: []string{"GET", "POST", "PUT", "HEAD", "DELETE"},
},
},
}
userJson, err := json.Marshal(user)
if err != nil {
Expand Down
13 changes: 10 additions & 3 deletions pkg/roles/api/auth/user.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,17 @@ import (
clientv3 "go.etcd.io/etcd/client/v3"
)

type Permission struct {
Path string `json:"path"`
Methods []string `json:"methods"`
}

type User struct {
ap *AuthProvider
Username string `json:"-"`
Password string `json:"password"`
ap *AuthProvider

Username string `json:"-"`
Password string `json:"password"`
Permissions []Permission `json:"permissions"`
}

func (u *User) String() string {
Expand Down
6 changes: 6 additions & 0 deletions pkg/roles/api/role.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"net/http"
"os"
"path"
"time"

sentryhttp "github.com/getsentry/sentry-go/http"

Expand Down Expand Up @@ -101,6 +102,11 @@ func (r *Role) Start(ctx context.Context, config []byte) error {
).String(),
cookieSecret,
)
sessDur := time.Hour * 24
if d, err := time.ParseDuration(r.cfg.SessionDuration); err == nil {
sessDur = d
}
sess.Options.MaxAge = int(sessDur.Seconds())
if err != nil {
return err
}
Expand Down
Loading

0 comments on commit b27935a

Please sign in to comment.