Skip to content

Commit

Permalink
Optional Security for HTTP Routes (fixes #169) (#256)
Browse files Browse the repository at this point in the history
Signed-off-by: Sotirios Mantziaris <[email protected]>
  • Loading branch information
mantzas authored Dec 18, 2018
1 parent d730f13 commit 3614b19
Show file tree
Hide file tree
Showing 15 changed files with 371 additions and 41 deletions.
24 changes: 23 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -252,4 +252,26 @@ The factory function type defines a factory for creating a logger.

```go
type FactoryFunc func(map[string]interface{}) Logger
```
```

## Security

The necessary abstraction are available to implement authentication in the following components:

- HTTP

### HTTP

In order to use authentication, a authenticator has to be implement following the interface:

```go
type Authenticator interface {
Authenticate(req *http.Request) (bool, error)
}
```

This authenticator can then be used to set up routes with authentication.

The following authenticator are available:

- API key authenticator, see examples
1 change: 1 addition & 0 deletions examples/first/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ func first(ctx context.Context, req *sync.Request) (*sync.Response, error) {
}
secondRouteReq.Header.Add("Content-Type", "application/json")
secondRouteReq.Header.Add("Accept", "application/json")
secondRouteReq.Header.Add("Authorization", "Apikey 123456")
cl, err := tracehttp.New(tracehttp.Timeout(5 * time.Second))
if err != nil {
return nil, err
Expand Down
19 changes: 18 additions & 1 deletion examples/second/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"github.com/mantzas/patron/log"
"github.com/mantzas/patron/sync"
patronhttp "github.com/mantzas/patron/sync/http"
"github.com/mantzas/patron/sync/http/auth/apikey"
tracehttp "github.com/mantzas/patron/trace/http"
"github.com/mantzas/patron/trace/kafka"
"github.com/pkg/errors"
Expand Down Expand Up @@ -55,9 +56,14 @@ func main() {
log.Fatalf("failed to create processor %v", err)
}

auth, err := apikey.New(&apiKeyValidator{validKey: "123456"})
if err != nil {
log.Fatalf("failed to create authenticator %v", err)
}

// Set up routes
routes := []patronhttp.Route{
patronhttp.NewGetRoute("/", httpCmp.second, true),
patronhttp.NewAuthGetRoute("/", httpCmp.second, true, auth),
}

srv, err := patron.New(
Expand Down Expand Up @@ -122,3 +128,14 @@ func (hc *httpComponent) second(ctx context.Context, req *sync.Request) (*sync.R
log.Infof("request processed: %s", m)
return sync.NewResponse(fmt.Sprintf("got %s from google", rsp.Status)), nil
}

type apiKeyValidator struct {
validKey string
}

func (av apiKeyValidator) Validate(key string) (bool, error) {
if key == av.validKey {
return true, nil
}
return false, nil
}
2 changes: 1 addition & 1 deletion option_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ func TestRoutes(t *testing.T) {
}{
{"failure due to empty routes", args{rr: []http.Route{}}, true},
{"failure due to nil routes", args{rr: nil}, true},
{"success", args{rr: []http.Route{http.NewRoute("/", "GET", nil, true)}}, false},
{"success", args{rr: []http.Route{http.NewRoute("/", "GET", nil, true, nil)}}, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
Expand Down
2 changes: 1 addition & 1 deletion service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import (
)

func TestNewServer(t *testing.T) {
route := http.NewRoute("/", "GET", nil, true)
route := http.NewRoute("/", "GET", nil, true, nil)
type args struct {
name string
opt OptionFunc
Expand Down
45 changes: 45 additions & 0 deletions sync/http/auth/apikey/apikey.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
package apikey

import (
"errors"
"net/http"
"strings"
)

// Validator interface for validating keys.
type Validator interface {
Validate(key string) (bool, error)
}

// Authenticator authenticates the request based on the header on the following header key and value:
// Authorization: Apikey {api key}, where {api key} is the key.
type Authenticator struct {
val Validator
}

// New constructor.
func New(val Validator) (*Authenticator, error) {
if val == nil {
return nil, errors.New("validator is nil")
}
return &Authenticator{val: val}, nil
}

// Authenticate parses the header for the specified key and authenticates it.
func (a *Authenticator) Authenticate(req *http.Request) (bool, error) {
headerVal := req.Header.Get("Authorization")
if headerVal == "" {
return false, nil
}

auth := strings.SplitN(headerVal, " ", 2)
if len(auth) != 2 {
return false, nil
}

if strings.ToLower(auth[0]) != "apikey" {
return false, nil
}

return a.val.Validate(auth[1])
}
97 changes: 97 additions & 0 deletions sync/http/auth/apikey/apikey_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
package apikey

import (
"errors"
"net/http"
"testing"

"github.com/stretchr/testify/assert"
)

type MockValidator struct {
err error
success bool
}

func (mv MockValidator) Validate(key string) (bool, error) {
if mv.err != nil {
return false, mv.err
}
return mv.success, nil
}

func TestNew(t *testing.T) {
type args struct {
val Validator
}
tests := []struct {
name string
args args
wantErr bool
}{
{name: "success", args: args{val: &MockValidator{}}, wantErr: false},
{name: "failed due to nil validator", args: args{val: nil}, wantErr: true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := New(tt.args.val)
if tt.wantErr {
assert.Error(t, err)
assert.Nil(t, got)
} else {
assert.NoError(t, err)
assert.NotNil(t, got)
}
})
}
}

func TestAuthenticator_Authenticate(t *testing.T) {
reqOk, err := http.NewRequest("POST", "/test", nil)
assert.NoError(t, err)
reqOk.Header.Set("Authorization", "Apikey 123456")
reqMissingHeader, err := http.NewRequest("POST", "/test", nil)
assert.NoError(t, err)
reqMissingKey, err := http.NewRequest("POST", "/test", nil)
assert.NoError(t, err)
reqMissingKey.Header.Set("Authorization", "Apikey")
reqInvalidAuthMethod, err := http.NewRequest("POST", "/test", nil)
assert.NoError(t, err)
reqInvalidAuthMethod.Header.Set("Authorization", "Bearer 123456")

type fields struct {
val Validator
}
type args struct {
req *http.Request
}
tests := []struct {
name string
fields fields
args args
want bool
wantErr bool
}{
{name: "authenticated", fields: fields{val: &MockValidator{success: true}}, args: args{req: reqOk}, want: true, wantErr: false},
{name: "not authenticated, validation failed", fields: fields{val: &MockValidator{success: false}}, args: args{req: reqOk}, want: false, wantErr: false},
{name: "failed, validation returned err", fields: fields{val: &MockValidator{err: errors.New("TEST")}}, args: args{req: reqOk}, want: false, wantErr: true},
{name: "not authenticated, header missing", fields: fields{val: &MockValidator{success: false}}, args: args{req: reqMissingHeader}, want: false, wantErr: false},
{name: "not authenticated, missing key", fields: fields{val: &MockValidator{success: false}}, args: args{req: reqMissingKey}, want: false, wantErr: false},
{name: "not authenticated, invalid auth method", fields: fields{val: &MockValidator{success: false}}, args: args{req: reqInvalidAuthMethod}, want: false, wantErr: false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
a := &Authenticator{
val: tt.fields.val,
}
got, err := a.Authenticate(tt.args.req)
if tt.wantErr {
assert.Error(t, err)
assert.False(t, got)
} else {
assert.NoError(t, err)
assert.Equal(t, tt.want, got)
}
})
}
}
10 changes: 10 additions & 0 deletions sync/http/auth/auth.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
package auth

import (
"net/http"
)

// Authenticator interface.
type Authenticator interface {
Authenticate(req *http.Request) (bool, error)
}
6 changes: 1 addition & 5 deletions sync/http/component.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,11 +73,7 @@ func (c *Component) Run(ctx context.Context) error {
c.Lock()
log.Debug("applying tracing to routes")
for i := 0; i < len(c.routes); i++ {
if c.routes[i].Trace {
c.routes[i].Handler = DefaultMiddleware(c.routes[i].Pattern, c.routes[i].Handler)
} else {
c.routes[i].Handler = RecoveryMiddleware(c.routes[i].Handler)
}
c.routes[i].Handler = Middleware(c.routes[i].Trace, c.routes[i].Auth, c.routes[i].Pattern, c.routes[i].Handler)
}
chFail := make(chan error)
srv := c.createHTTPServer()
Expand Down
16 changes: 8 additions & 8 deletions sync/http/component_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ func TestNew(t *testing.T) {
}

func TestComponent_ListenAndServe_DefaultRoutes_Shutdown(t *testing.T) {
rr := []Route{NewRoute("/", "GET", nil, true)}
s, err := New(Routes(rr))
rr := []Route{NewRoute("/", "GET", nil, true, nil)}
s, err := New(Routes(rr), Port(50003))
assert.NoError(t, err)
done := make(chan bool)
ctx, cnl := context.WithCancel(context.Background())
Expand All @@ -56,8 +56,8 @@ func TestComponent_ListenAndServe_DefaultRoutes_Shutdown(t *testing.T) {
}

func TestComponent_ListenAndServeTLS_DefaultRoutes_Shutdown(t *testing.T) {
rr := []Route{NewRoute("/", "GET", nil, true)}
s, err := New(Routes(rr), Secure("testdata/server.pem", "testdata/server.key"))
rr := []Route{NewRoute("/", "GET", nil, true, nil)}
s, err := New(Routes(rr), Secure("testdata/server.pem", "testdata/server.key"), Port(50001))
assert.NoError(t, err)
done := make(chan bool)
ctx, cnl := context.WithCancel(context.Background())
Expand All @@ -72,12 +72,12 @@ func TestComponent_ListenAndServeTLS_DefaultRoutes_Shutdown(t *testing.T) {
}

func TestInfo(t *testing.T) {
rr := []Route{NewRoute("/", "GET", nil, true)}
s, err := New(Routes(rr), Secure("testdata/server.pem", "testdata/server.key"))
rr := []Route{NewRoute("/", "GET", nil, true, nil)}
s, err := New(Routes(rr), Secure("testdata/server.pem", "testdata/server.key"), Port(50005))
assert.NoError(t, err)
expected := make(map[string]interface{})
expected["type"] = "https"
expected["port"] = 50000
expected["port"] = 50005
expected["read-timeout"] = httpReadTimeout.String()
expected["write-timeout"] = httpWriteTimeout.String()
expected["idle-timeout"] = httpIdleTimeout.String()
Expand All @@ -87,7 +87,7 @@ func TestInfo(t *testing.T) {
}

func TestComponent_ListenAndServeTLS_FailsInvalidCerts(t *testing.T) {
rr := []Route{NewRoute("/", "GET", nil, true)}
rr := []Route{NewRoute("/", "GET", nil, true, nil)}
s, err := New(Routes(rr), Secure("testdata/server.pem", "testdata/server.pem"))
assert.NoError(t, err)
assert.Error(t, s.Run(context.Background()))
Expand Down
2 changes: 1 addition & 1 deletion sync/http/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ func Test_extractParams(t *testing.T) {
}

router := httprouter.New()
route := NewRoute("/users/:id/status", "GET", proc, false)
route := NewRoute("/users/:id/status", "GET", proc, false, nil)
router.HandlerFunc(route.Method, route.Pattern, route.Handler)
router.ServeHTTP(httptest.NewRecorder(), req)
assert.Equal(t, "1", fields["id"])
Expand Down
39 changes: 32 additions & 7 deletions sync/http/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (

"github.com/mantzas/patron/errors"
"github.com/mantzas/patron/log"
"github.com/mantzas/patron/sync/http/auth"
"github.com/mantzas/patron/trace"
)

Expand Down Expand Up @@ -51,13 +52,21 @@ func (w *responseWriter) WriteHeader(code int) {
w.statusHeaderWritten = true
}

// DefaultMiddleware which handles tracing and recovery.
func DefaultMiddleware(path string, next http.HandlerFunc) http.HandlerFunc {
return TracingMiddleware(path, RecoveryMiddleware(next))
// Middleware which returns all selected middlewares.
func Middleware(trace bool, auth auth.Authenticator, path string, next http.HandlerFunc) http.HandlerFunc {
if trace {
if auth == nil {
return tracingMiddleware(path, recoveryMiddleware(next))
}
return tracingMiddleware(path, authMiddleware(auth, recoveryMiddleware(next)))
}
if auth == nil {
return recoveryMiddleware(next)
}
return authMiddleware(auth, recoveryMiddleware(next))
}

// TracingMiddleware for handling tracing and metrics.
func TracingMiddleware(path string, next http.HandlerFunc) http.HandlerFunc {
func tracingMiddleware(path string, next http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
sp, r := trace.HTTPSpan(path, r)
lw := newResponseWriter(w)
Expand All @@ -66,8 +75,7 @@ func TracingMiddleware(path string, next http.HandlerFunc) http.HandlerFunc {
}
}

// RecoveryMiddleware for recovering from failed requests.
func RecoveryMiddleware(next http.HandlerFunc) http.HandlerFunc {
func recoveryMiddleware(next http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
defer func() {
if r := recover(); r != nil {
Expand All @@ -88,3 +96,20 @@ func RecoveryMiddleware(next http.HandlerFunc) http.HandlerFunc {
next(w, r)
}
}

func authMiddleware(auth auth.Authenticator, next http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
authenticated, err := auth.Authenticate(r)
if err != nil {
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
return
}

if !authenticated {
http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
return
}

next(w, r)
}
}
Loading

0 comments on commit 3614b19

Please sign in to comment.