From b6d18eb4ccd5a9b20c3732bbbdfd2cf95a546074 Mon Sep 17 00:00:00 2001 From: Igor Beliakov <46579601+weisdd@users.noreply.github.com> Date: Sun, 1 May 2022 16:34:33 +0200 Subject: [PATCH] Added more tests (#29) - Key changes: - Minor improvements in docs; - Minor improvements in logging; - Added more tests; - `VictoriaMetrics/metricsql` bumped from `0.42.0` to `0.43.0`. --- CHANGELOG.md | 8 + README.md | 7 +- Taskfile.yml | 6 + internal/lfgw/helpers.go | 12 +- internal/lfgw/logging.go | 1 + internal/lfgw/main.go | 87 +++++--- internal/lfgw/middleware.go | 32 ++- internal/lfgw/middleware_test.go | 315 +++++++++++++++++++++++++++- internal/lfgw/server.go | 5 + internal/querymodifier/acls.go | 7 +- internal/querymodifier/acls_test.go | 6 + 11 files changed, 439 insertions(+), 47 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 1aec77f..c79472b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,13 @@ # CHANGELOG +## 0.12.2 + +- Key changes: + - Minor improvements in docs; + - Minor improvements in logging; + - Added more tests; + - `VictoriaMetrics/metricsql` bumped from `0.42.0` to `0.43.0`. + ## 0.12.1 - Key changes: diff --git a/README.md b/README.md index dbd9855..0a5278b 100644 --- a/README.md +++ b/README.md @@ -27,12 +27,12 @@ Docker images are published on [ghcr.io/weisdd/lfgw](https://github.com/weisdd/l ## Configuration -Example of `keycloak + grafana + lfgw` setup is described [here](./docs/oidc.md). +Example of `keycloak + grafana + lfgw` setup is described [here](docs/oidc.md). ### Requirements for jwt-tokens * OIDC-roles must be present in `roles` claim; -* Client ID specified via `OIDC_CLIENT_ID` must be present in `aud` claim (more details in [environment variables section](#Environment variables)), otherwise token verification will fail. +* Client ID specified via `OIDC_CLIENT_ID` must be present in `aud` claim (more details in [environment variables section](#environment-variables)), otherwise token verification will fail. ### Environment variables @@ -105,11 +105,10 @@ Note: a user is free to have multiple roles matching the contents of `acl.yaml`. ## Licensing -lfgw code is licensed under MIT, though its dependencies might have other licenses. Please, inspect the modules listed in [go.mod](./go.mod) if needed. +lfgw code is licensed under MIT, though its dependencies might have other licenses. Please, inspect the modules listed in [go.mod](go.mod) if needed. ## TODO -* tests for handlers; * improve naming; * log slow requests; * metrics; diff --git a/Taskfile.yml b/Taskfile.yml index 787494f..8ad7925 100644 --- a/Taskfile.yml +++ b/Taskfile.yml @@ -18,3 +18,9 @@ tasks: tidy: cmds: - go mod tidy + + cov: + cmds: + - mkdir -p temp + - go test -coverprofile=temp/coverage.out ./... + - go tool cover -html=temp/coverage.out diff --git a/internal/lfgw/helpers.go b/internal/lfgw/helpers.go index 01cf9a7..1b9d15d 100644 --- a/internal/lfgw/helpers.go +++ b/internal/lfgw/helpers.go @@ -43,7 +43,17 @@ func (app *application) getRawAccessToken(r *http.Request) (string, error) { } } - return "", fmt.Errorf("no bearer token found") + var err error + + isGrafanaRequest := strings.Contains(strings.ToLower(r.UserAgent()), "grafana") + + if isGrafanaRequest { + err = fmt.Errorf("no bearer token found, possible causes: grafana data source is not configured with Forward Oauth Identity option; grafana user sessions are not tuned to live shorter than IDP sessions; malicious requests") + } else { + err = fmt.Errorf("no bearer token found") + } + + return "", err } // isNotAPIRequest returns true if the requested path does not target API or federate endpoints. diff --git a/internal/lfgw/logging.go b/internal/lfgw/logging.go index 06ee2d5..0e28e67 100644 --- a/internal/lfgw/logging.go +++ b/internal/lfgw/logging.go @@ -31,6 +31,7 @@ func (s stdErrorLogWrapper) Write(p []byte) (n int, err error) { return len(p), nil } +// configureLogging configures zerolog and sets the respective fields in the application struct func (app *application) configureLogging() { zlog.Logger = zlog.Output(os.Stdout) app.logger = &zlog.Logger diff --git a/internal/lfgw/main.go b/internal/lfgw/main.go index 3daf1b9..3dca22c 100644 --- a/internal/lfgw/main.go +++ b/internal/lfgw/main.go @@ -44,6 +44,18 @@ type application struct { logger *zerolog.Logger } +// Run is used as an entrypoint for cli +func Run(c *cli.Context) error { + app, err := newApplication(c) + if err != nil { + return err + } + + app.Run() + + return nil +} + // newApplication returns application struct built from *cli.Context func newApplication(c *cli.Context) (application, error) { upstreamURL, err := url.Parse(c.String("upstream-url")) @@ -75,22 +87,13 @@ func newApplication(c *cli.Context) (application, error) { return app, nil } -// Run is used as an entrypoint for cli -func Run(c *cli.Context) error { - app, err := newApplication(c) - if err != nil { - return err - } - - app.Run() - - return nil -} - // Run starts lfgw (main-like function) func (app *application) Run() { app.configureLogging() + app.configureACLs() + app.configureOIDCVerifier() + // TODO: expose undo and move to another function? if app.SetGomaxProcs { undo, err := maxprocs.Set() defer undo() @@ -102,6 +105,20 @@ func (app *application) Run() { app.logger.Info().Caller(). Msgf("Runtime settings: GOMAXPROCS = %d", runtime.GOMAXPROCS(0)) + err := app.serve() + if err != nil { + app.logger.Fatal().Caller(). + Err(err).Msg("") + } +} + +// configureACLs logs assumed roles mode, verifies current ACLs settings (assumed roles, aclpath), loads the ACLs from a file and logs roles if needed +func (app *application) configureACLs() { + // Just to make sure our logging calls are always safe + if app.logger == nil { + app.configureLogging() + } + if app.AssumedRolesEnabled { app.logger.Info().Caller(). Msg("Assumed roles mode is on") @@ -110,20 +127,7 @@ func (app *application) Run() { Msg("Assumed roles mode is off") } - var err error - - if app.ACLPath != "" { - app.ACLs, err = querymodifier.NewACLsFromFile(app.ACLPath) - if err != nil { - app.logger.Fatal().Caller(). - Err(err).Msgf("Failed to load ACL") - } - - for role, acl := range app.ACLs { - app.logger.Info().Caller(). - Msgf("Loaded role definition for %s: %q (converted to %s)", role, acl.RawACL, acl.LabelFilter.AppendString(nil)) - } - } else { + if app.ACLPath == "" { // NOTE: the condition should never happen as it's filtered out by "Before" functionality of cli, though left just in case if !app.AssumedRolesEnabled { app.logger.Fatal().Caller(). @@ -131,7 +135,30 @@ func (app *application) Run() { } app.logger.Info().Caller(). - Msgf("ACL_PATH is empty, thus predefined roles are not loaded") + Msgf("ACL_PATH is empty, thus predefined roles will not be loaded") + + return + } + + var err error + + app.ACLs, err = querymodifier.NewACLsFromFile(app.ACLPath) + if err != nil { + app.logger.Fatal().Caller(). + Err(err).Msgf("Failed to load ACL") + } + + for role, acl := range app.ACLs { + app.logger.Info().Caller(). + Msgf("Loaded role definition for %s: %q (converted to %s)", role, acl.RawACL, acl.LabelFilter.AppendString(nil)) + } +} + +// configureOIDCVerifier sets up OIDC token verifier by using app.OIDCRealmURL and app.OIDCClientID +func (app *application) configureOIDCVerifier() { + // Just to make sure our logging calls are always safe + if app.logger == nil { + app.configureLogging() } app.logger.Info().Caller(). @@ -149,10 +176,4 @@ func (app *application) Run() { ClientID: app.OIDCClientID, } app.verifier = provider.Verifier(oidcConfig) - - err = app.serve() - if err != nil { - app.logger.Fatal().Caller(). - Err(err).Msg("") - } } diff --git a/internal/lfgw/middleware.go b/internal/lfgw/middleware.go index 0ac0d58..e23a6cf 100644 --- a/internal/lfgw/middleware.go +++ b/internal/lfgw/middleware.go @@ -55,6 +55,10 @@ func (app *application) logMiddleware(next http.Handler) http.Handler { // If any of those are empty, they won't get logged app.enrichDebugLogContext(r, "get_params", app.unescapedURLQuery(r.URL.Query().Encode())) app.enrichDebugLogContext(r, "post_params", app.unescapedURLQuery(postForm)) + + // Workaround to make further r.ParseForm() calls update r.Form and r.PostForm again, might be useful in case there's another middleware before rewriteRequestMiddleware + r.Form = nil + r.PostForm = nil } if app.LogRequests || app.Debug { @@ -106,6 +110,10 @@ func (app *application) oidcModeMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { rawAccessToken, err := app.getRawAccessToken(r) if err != nil { + // Might produce plenty of error messages, though it will make it much easier to understand why requests are failing + hlog.FromRequest(r).Error().Caller(). + Err(err).Msg("") + app.clientErrorMessage(w, http.StatusUnauthorized, err) return } @@ -158,16 +166,15 @@ func (app *application) oidcModeMiddleware(next http.Handler) http.Handler { // rewriteRequestMiddleware rewrites a request before forwarding it to the upstream. func (app *application) rewriteRequestMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Rewrite request destination - r.Host = app.UpstreamURL.Host - - if app.isNotAPIRequest(r.URL.Path) { - hlog.FromRequest(r).Debug().Caller(). - Msg("Not an API request, request is not modified") - next.ServeHTTP(w, r) + // TODO: rewrite? + if app.UpstreamURL == nil { + app.serverError(w, r, fmt.Errorf("UpstreamURL is not initialized")) return } + // Rewrite request destination + r.Host = app.UpstreamURL.Host + acl, ok := r.Context().Value(contextKeyACL).(querymodifier.ACL) if !ok { // Should never happen. It means OIDC middleware hasn't done it's job @@ -175,6 +182,13 @@ func (app *application) rewriteRequestMiddleware(next http.Handler) http.Handler return } + if app.isNotAPIRequest(r.URL.Path) { + hlog.FromRequest(r).Debug().Caller(). + Msg("Not an API request, request is not modified") + next.ServeHTTP(w, r) + return + } + if acl.Fullaccess { hlog.FromRequest(r).Debug().Caller(). Msg("User has full access, request is not modified") @@ -219,6 +233,10 @@ func (app *application) rewriteRequestMiddleware(next http.Handler) http.Handler // TODO: the field name is slightly misleading, should, probably, be renamed app.enrichDebugLogContext(r, "new_post_params", app.unescapedURLQuery(newPostParams)) + // Workaround to make further r.ParseForm() calls update r.Form and r.PostForm again, might be useful in case there's another middleware before rewriteRequestMiddleware + r.Form = nil + r.PostForm = nil + next.ServeHTTP(w, r) }) } diff --git a/internal/lfgw/middleware_test.go b/internal/lfgw/middleware_test.go index 03214b5..c928ea5 100644 --- a/internal/lfgw/middleware_test.go +++ b/internal/lfgw/middleware_test.go @@ -1,14 +1,23 @@ package lfgw import ( + "context" + "fmt" + "io" "net/http" "net/http/httptest" + "net/url" + "strings" "testing" "github.com/rs/zerolog" "github.com/stretchr/testify/assert" + "github.com/weisdd/lfgw/internal/querymodifier" ) +// TODO: logMiddleware add a test https://go.dev/src/net/http/httputil/reverseproxy_test.go +// to make sure such errors don't happen: reverseproxy.go:489 > error="http: proxy error: net/http: HTTP/1.x transport connection broken: http: ContentLength=57 with Body length 0\n" + func Test_safeModeMiddleware(t *testing.T) { tests := []struct { name string @@ -69,14 +78,16 @@ func Test_safeModeMiddleware(t *testing.T) { SafeMode: tt.safeMode, } - rr := httptest.NewRecorder() r, err := http.NewRequest(tt.method, tt.path, nil) if err != nil { t.Fatal(err) } + next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { _, _ = w.Write([]byte("OK")) }) + + rr := httptest.NewRecorder() app.safeModeMiddleware(next).ServeHTTP(rr, r) rs := rr.Result() got := rs.StatusCode @@ -87,3 +98,305 @@ func Test_safeModeMiddleware(t *testing.T) { }) } } + +func Test_proxyHeadersMiddleware(t *testing.T) { + // Just to hold reference values + headers := map[string]string{ + "X-Forwarded-For": "1.2.3.4", + "X-Forwarded-Proto": "http", + "X-Forwarded-Host": "lfgw", + } + + // Set the values that will be used by middleware to set new headers in case app.SetProxyHeaders = true + r, err := http.NewRequest(http.MethodGet, fmt.Sprintf("%s://lfgw", headers["X-Forwarded-Proto"]), nil) + if err != nil { + t.Fatal(err) + } + r.Header.Set("Host", headers["X-Forwarded-Host"]) + r.RemoteAddr = headers["X-Forwarded-For"] + + t.Run("Proxy headers are set", func(t *testing.T) { + app := &application{ + SetProxyHeaders: true, + } + + next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + for h, want := range headers { + got := r.Header.Get(h) + assert.Equal(t, want, got, fmt.Sprintf("%s is set to a different value", h)) + } + _, _ = w.Write([]byte("OK")) + }) + + rr := httptest.NewRecorder() + // Better to clone the request to make sure tests don't interfere with each other + app.proxyHeadersMiddleware(next).ServeHTTP(rr, r.Clone(r.Context())) + rs := rr.Result() + + got := rs.StatusCode + want := http.StatusOK + + assert.Equal(t, want, got) + + defer rs.Body.Close() + }) + + t.Run("Proxy headers are NOT set", func(t *testing.T) { + app := &application{ + SetProxyHeaders: false, + } + + next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + for h := range headers { + assert.Empty(t, r.Header.Get(h), fmt.Sprintf("%s must be empty", h)) + } + _, _ = w.Write([]byte("OK")) + }) + + rr := httptest.NewRecorder() + // Better to clone the request to make sure tests don't interfere with each other + app.proxyHeadersMiddleware(next).ServeHTTP(rr, r.Clone(r.Context())) + rs := rr.Result() + + got := rs.StatusCode + want := http.StatusOK + + assert.Equal(t, want, got) + + defer rs.Body.Close() + }) + +} + +func Test_rewriteRequestMiddleware(t *testing.T) { + logger := zerolog.New(nil) + + t.Run("UpstreamURL is not set", func(t *testing.T) { + app := &application{ + logger: &logger, + UpstreamURL: nil, + } + + r, err := http.NewRequest(http.MethodGet, "http://lfgw/api/v1/federate", nil) + if err != nil { + t.Fatal(err) + } + + next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = w.Write([]byte("OK")) + }) + + rr := httptest.NewRecorder() + app.rewriteRequestMiddleware(next).ServeHTTP(rr, r) + rs := rr.Result() + + got := rs.StatusCode + want := http.StatusInternalServerError + + // TODO: check logs for the error message? + assert.Equal(t, want, got) + + defer rs.Body.Close() + }) + + // TODO:rewrite once something is done with app.UpstreamURL + upstreamURL, err := url.Parse("http://prometheus") + assert.Nil(t, err) + + app := &application{ + logger: &logger, + UpstreamURL: upstreamURL, + } + + t.Run("ACL is not in the context", func(t *testing.T) { + r, err := http.NewRequest(http.MethodGet, "http://lfgw/api/v1/federate", nil) + if err != nil { + t.Fatal(err) + } + next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = w.Write([]byte("OK")) + }) + + rr := httptest.NewRecorder() + app.rewriteRequestMiddleware(next).ServeHTTP(rr, r) + rs := rr.Result() + + got := rs.StatusCode + want := http.StatusInternalServerError + + // TODO: check logs for the error message? + assert.Equal(t, want, got) + + defer rs.Body.Close() + }) + + t.Run("Not an API request", func(t *testing.T) { + r, err := http.NewRequest(http.MethodGet, "http://lfgw/fakeapi/v1/query?query=kube_pod_info", nil) + if err != nil { + t.Fatal(err) + } + + acl, err := querymodifier.NewACL("monitoring") + assert.Nil(t, err) + + ctx := context.WithValue(r.Context(), contextKeyACL, acl) + r = r.WithContext(ctx) + + next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + got, err := url.QueryUnescape(r.URL.RawQuery) + assert.Nil(t, err) + + want := "query=kube_pod_info" + assert.Equal(t, want, got) + + _, _ = w.Write([]byte("OK")) + }) + + rr := httptest.NewRecorder() + app.rewriteRequestMiddleware(next).ServeHTTP(rr, r) + rs := rr.Result() + + got := rs.StatusCode + want := http.StatusOK + + // TODO: check logs for the error message? + assert.Equal(t, want, got) + + defer rs.Body.Close() + }) + + t.Run("User has full access, API request is not modified", func(t *testing.T) { + r, err := http.NewRequest(http.MethodGet, "http://lfgw/api/v1/query?query=kube_pod_info", nil) + if err != nil { + t.Fatal(err) + } + + acl, err := querymodifier.NewACL(".*") + assert.Nil(t, err) + + ctx := context.WithValue(r.Context(), contextKeyACL, acl) + r = r.WithContext(ctx) + + next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + got, err := url.QueryUnescape(r.URL.RawQuery) + assert.Nil(t, err) + + want := "query=kube_pod_info" + assert.Equal(t, want, got) + + _, _ = w.Write([]byte("OK")) + }) + + rr := httptest.NewRecorder() + app.rewriteRequestMiddleware(next).ServeHTTP(rr, r) + rs := rr.Result() + + got := rs.StatusCode + want := http.StatusOK + + // TODO: check logs for the message? + assert.Equal(t, want, got) + + defer rs.Body.Close() + }) + + // TODO: merge GET & POST tests? + + t.Run("API request is modified according to an ACL (GET)", func(t *testing.T) { + r, err := http.NewRequest(http.MethodGet, "http://lfgw/api/v1/query?query=kube_pod_info", nil) + if err != nil { + t.Fatal(err) + } + + acl, err := querymodifier.NewACL("monitoring") + assert.Nil(t, err) + + ctx := context.WithValue(r.Context(), contextKeyACL, acl) + r = r.WithContext(ctx) + + next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Workaround to make r.ParseForm update r.Form and r.PostForm again + r.Form = nil + r.PostForm = nil + + err := r.ParseForm() + assert.Nil(t, err) + + want := url.Values{ + "query": {`kube_pod_info{namespace="monitoring"}`}, + } + got := r.Form + + assert.Equal(t, want, got) + + _, _ = w.Write([]byte("OK")) + }) + + rr := httptest.NewRecorder() + app.rewriteRequestMiddleware(next).ServeHTTP(rr, r) + rs := rr.Result() + + got := rs.StatusCode + want := http.StatusOK + + // TODO: check logs for the error message? + assert.Equal(t, want, got) + + defer rs.Body.Close() + }) + + t.Run("API request is modified according to an ACL (POST)", func(t *testing.T) { + body := io.NopCloser(strings.NewReader("query=kube_pod_info")) + + r, err := http.NewRequest(http.MethodPost, "http://lfgw/api/v1/query", body) + if err != nil { + t.Fatal(err) + } + + // Requests of a different type are not decoded by r.ParseForm() + r.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + acl, err := querymodifier.NewACL("monitoring") + assert.Nil(t, err) + + ctx := context.WithValue(r.Context(), contextKeyACL, acl) + r = r.WithContext(ctx) + + next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Workaround to make r.ParseForm update r.Form and r.PostForm again + r.Form = nil + r.PostForm = nil + + err := r.ParseForm() + assert.Nil(t, err) + + want := url.Values{ + "query": {`kube_pod_info{namespace="monitoring"}`}, + } + got := r.PostForm + + assert.Equal(t, want, got) + + postForm := r.PostForm.Encode() + newBody := strings.NewReader(postForm) + r.ContentLength = newBody.Size() + r.Body = io.NopCloser(newBody) + + _, _ = w.Write([]byte("OK")) + }) + + rr := httptest.NewRecorder() + app.rewriteRequestMiddleware(next).ServeHTTP(rr, r) + rs := rr.Result() + + got := rs.StatusCode + want := http.StatusOK + + // TODO: check logs for the error message? + assert.Equal(t, want, got) + + defer rs.Body.Close() + }) + + // TODO: log fields are added (both get / post) +} diff --git a/internal/lfgw/server.go b/internal/lfgw/server.go index ba8c95e..be66313 100644 --- a/internal/lfgw/server.go +++ b/internal/lfgw/server.go @@ -14,6 +14,11 @@ import ( // serve starts a web server and ensures graceful shutdown func (app *application) serve() error { + // Just to make sure our logging calls are always safe + if app.logger == nil { + app.configureLogging() + } + app.proxy = httputil.NewSingleHostReverseProxy(app.UpstreamURL) // TODO: somehow pass more context to ErrorLog (unsafe?) app.proxy.ErrorLog = app.errorLog diff --git a/internal/querymodifier/acls.go b/internal/querymodifier/acls.go index 8c1b4de..64c9c05 100644 --- a/internal/querymodifier/acls.go +++ b/internal/querymodifier/acls.go @@ -85,10 +85,15 @@ func (a ACLs) GetUserACL(oidcRoles []string, assumedRolesEnabled bool) (ACL, err return acl, nil } -// NewACLsFromFile loads ACL from a file +// NewACLsFromFile loads ACL from a file or returns an empty ACLs instance if path is empty func NewACLsFromFile(path string) (ACLs, error) { acls := make(ACLs) + path = strings.TrimSpace(path) + if path == "" { + return acls, nil + } + yamlFile, err := os.ReadFile(path) if err != nil { return ACLs{}, err diff --git a/internal/querymodifier/acls_test.go b/internal/querymodifier/acls_test.go index 22e08af..a2d883f 100644 --- a/internal/querymodifier/acls_test.go +++ b/internal/querymodifier/acls_test.go @@ -310,6 +310,12 @@ func TestACL_NewACLsFromFile(t *testing.T) { }) } + t.Run("empty path", func(t *testing.T) { + got, err := NewACLsFromFile("") + assert.Nil(t, err) + assert.Equal(t, ACLs{}, got) + }) + t.Run("incorrect ACL", func(t *testing.T) { saveACLToFile(t, f, "test-role:") _, err := NewACLsFromFile(f.Name())