From d80f9906b59be4b1529f63a945e2dd9b87543ccb Mon Sep 17 00:00:00 2001 From: Mustafa Saber Date: Wed, 31 May 2023 13:33:04 +0300 Subject: [PATCH] Support `Etag` from DataClient prespective (#2304) * Support `Etag` from DataClient prespective Signed-off-by: Mustafa Abdelrahman --- eskipfile/remote.go | 70 +++++++---- eskipfile/remote_test.go | 251 ++++++++++++++++++++++++++------------- 2 files changed, 220 insertions(+), 101 deletions(-) diff --git a/eskipfile/remote.go b/eskipfile/remote.go index fe8041fd05..02de85bfa3 100644 --- a/eskipfile/remote.go +++ b/eskipfile/remote.go @@ -2,7 +2,9 @@ package eskipfile import ( "errors" + "fmt" "io" + "net/http" "os" "strings" "sync" @@ -15,6 +17,8 @@ import ( log "github.com/sirupsen/logrus" ) +var errContentNotChanged = errors.New("content in cache did not change, 304 reponse status code") + type remoteEskipFile struct { once sync.Once preloaded bool @@ -24,6 +28,7 @@ type remoteEskipFile struct { threshold int verbose bool http *net.Client + etag string } type RemoteWatchOptions struct { @@ -122,19 +127,21 @@ func (client *remoteEskipFile) LoadUpdate() ([]*eskip.Route, []string, error) { } newRoutes, deletedRoutes, err := client.eskipFileClient.LoadUpdate() - if err == nil { - if client.verbose { - log.Infof("New routes were loaded. New: %d; deleted: %d", len(newRoutes), len(deletedRoutes)) - - if client.threshold > 0 { - if len(newRoutes)+len(deletedRoutes) > client.threshold { - log.Warnf("Significant amount of routes was updated. New: %d; deleted: %d", len(newRoutes), len(deletedRoutes)) - } - } - } - } else { + + if err != nil { log.Errorf("RemoteEskipFile LoadUpdate %s failed. Skipper continues to serve the last successfully updated routes. Error: %s", client.remotePath, err) + return newRoutes, deletedRoutes, err + } + + if client.verbose { + log.Infof("New routes were loaded. New: %d; deleted: %d", len(newRoutes), len(deletedRoutes)) + + if client.threshold > 0 { + if len(newRoutes)+len(deletedRoutes) > client.threshold { + log.Warnf("Significant amount of routes was updated. New: %d; deleted: %d", len(newRoutes), len(deletedRoutes)) + } + } } return newRoutes, deletedRoutes, err @@ -152,33 +159,56 @@ func isFileRemote(remotePath string) bool { } func (client *remoteEskipFile) DownloadRemoteFile() error { - data, err := client.getRemoteData() + resBody, err := client.getRemoteData() if err != nil { + if errors.Is(err, errContentNotChanged) { + return nil + } return err } - defer data.Close() + defer resBody.Close() - out, err := os.OpenFile(client.localPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600) + outFile, err := os.OpenFile(client.localPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600) if err != nil { return err } + defer outFile.Close() - if _, err = io.Copy(out, data); err != nil { - _ = out.Close() + if _, err = io.Copy(outFile, resBody); err != nil { + _ = outFile.Close() return err } - return out.Close() + return outFile.Close() } func (client *remoteEskipFile) getRemoteData() (io.ReadCloser, error) { - resp, err := client.http.Get(client.remotePath) + req, err := http.NewRequest("GET", client.remotePath, nil) + + if err != nil { + return nil, err + } + + if client.etag != "" { + req.Header.Set("If-None-Match", client.etag) + } + + resp, err := client.http.Do(req) if err != nil { return nil, err } + + if client.etag != "" && resp.StatusCode == 304 { + resp.Body.Close() + return nil, errContentNotChanged + } + if resp.StatusCode != 200 { - return nil, errors.New("download file failed") + resp.Body.Close() + return nil, fmt.Errorf("failed to download remote file %s, status code: %d", client.remotePath, resp.StatusCode) } - return resp.Body, nil + client.etag = resp.Header.Get("ETag") + + return resp.Body, err } diff --git a/eskipfile/remote_test.go b/eskipfile/remote_test.go index 92aa5495de..e0530027eb 100644 --- a/eskipfile/remote_test.go +++ b/eskipfile/remote_test.go @@ -6,10 +6,12 @@ import ( "net" "net/http" "net/http/httptest" + "sync/atomic" "testing" "time" - "github.com/google/go-cmp/cmp" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/zalando/skipper/eskip" ) @@ -44,11 +46,7 @@ func TestIsRemoteFile(t *testing.T) { } { t.Run(test.title, func(t *testing.T) { result := isFileRemote(test.file) - - if result != test.expected { - t.Error("isRemoteFile failed") - t.Log(test) - } + assert.Equal(t, result, test.expected) }) } } @@ -69,107 +67,100 @@ func TestLoadAll(t *testing.T) { title: "Download valid remote file", routeContent: fmt.Sprintf("VALID: %v;", routeBody), routeStatusCode: 200, - expected: []*eskip.Route{{ - Id: "VALID", - Path: "/", - Filters: []*eskip.Filter{{ - Name: "setPath", - Args: []interface{}{ - "/homepage/", - }, - }}, - BackendType: eskip.NetworkBackend, - Shunt: false, - Backend: "https://example.com/", - }}, + expected: eskip.MustParse(fmt.Sprintf("VALID: %v;", routeBody)), }, } { - s := createTestServer(test.routeContent, test.routeStatusCode) - defer s.Close() + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(test.routeStatusCode) + io.WriteString(w, test.routeContent) + })) + defer ts.Close() t.Run(test.title, func(t *testing.T) { - options := &RemoteWatchOptions{RemoteFile: s.URL, Threshold: 10, Verbose: true, FailOnStartup: true} + options := &RemoteWatchOptions{RemoteFile: ts.URL, Threshold: 10, Verbose: true, FailOnStartup: true} client, err := RemoteWatch(options) - if err == nil { - defer client.(*remoteEskipFile).Close() - } - if err == nil && test.fail { - t.Error("failed to fail") - return - } else if err != nil && !test.fail { - t.Error(err) - return - } else if test.fail { + if test.fail { + assert.Error(t, err) return } - r, err := client.LoadAll() - if err != nil { - t.Error(err) - return - } + require.NoError(t, err) - if len(r) == 0 { - r = nil - } + defer client.(*remoteEskipFile).Close() - if !cmp.Equal(r, test.expected) { - t.Errorf("invalid routes received\n%s", cmp.Diff(r, test.expected)) - } + r, err := client.LoadAll() + require.NoError(t, err) + + assert.Equal(t, r, test.expected) }) } } func TestLoadAllAndUpdate(t *testing.T) { for _, test := range []struct { - title string - validRouteContent string - invalidRouteContent string - expectedToFail bool - fail bool + title string + content string + contentUpdated string + expectedToFail bool + fail bool }{{ - title: "Download invalid update and all routes returns routes nil", - validRouteContent: fmt.Sprintf("VALID: %v;", routeBody), - invalidRouteContent: fmt.Sprintf("MISSING_SEMICOLON: %v\nVALID: %v;", routeBody, routeBody), - expectedToFail: true, + title: "Download invalid update and all routes returns routes nil", + content: fmt.Sprintf("VALID: %v;", routeBody), + contentUpdated: fmt.Sprintf("MISSING_SEMICOLON: %v\nVALID: %v;", routeBody, routeBody), + expectedToFail: true, + }, { + title: "Download valid update and all routes returns routes", + content: fmt.Sprintf("VALID: %v;", routeBody), + contentUpdated: fmt.Sprintf("DIFFERENT_ID: %v;\nVALID: %v;", routeBody, routeBody), + expectedToFail: false, }, } { t.Run(test.title, func(t *testing.T) { - testValidServer := createTestServer(test.validRouteContent, 200) - defer testValidServer.Close() + ch := make(chan string, 1) + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + routeString := <-ch + t.Logf("server routes: %v", routeString) + io.WriteString(w, routeString) + })) - options := &RemoteWatchOptions{RemoteFile: testValidServer.URL, Threshold: 10, Verbose: true, FailOnStartup: true} - client, err := RemoteWatch(options) - if err == nil { - defer client.(*remoteEskipFile).Close() - } + defer ts.Close() - if err == nil && test.fail { - t.Error("failed to fail") - return - } else if err != nil && !test.fail { - t.Error(err) - return - } else if test.fail { + options := &RemoteWatchOptions{RemoteFile: ts.URL, Threshold: 10, Verbose: true, FailOnStartup: true} + ch <- test.content + client, err := RemoteWatch(options) + if test.fail { + assert.Error(t, err) return } - testInvalidServer := createTestServer(test.invalidRouteContent, 200) - defer testInvalidServer.Close() + require.NoError(t, err) - client.(*remoteEskipFile).remotePath = testInvalidServer.URL - _, _, err = client.LoadUpdate() - if test.expectedToFail && err == nil { - t.Error(err) - return - } + defer client.(*remoteEskipFile).Close() + + t.Logf("local path is: %s", client.(*remoteEskipFile).localPath) + + ch <- test.content + r, err := client.LoadAll() + require.NoError(t, err) - _, err = client.LoadAll() - if test.expectedToFail && err == nil { - t.Error(err) + expected := eskip.MustParse(test.content) + + assert.Equal(t, r, expected) + + ch <- test.contentUpdated + r, _, err = client.LoadUpdate() + t.Logf("routes returned: %+v", r) + + if test.expectedToFail { + assert.Error(t, err) return } + require.NoError(t, err) + + expected = eskip.MustParse(fmt.Sprintf("DIFFERENT_ID: %v;", routeBody)) + + assert.Equal(t, r, expected) }) } } @@ -189,9 +180,107 @@ func TestHTTPTimeout(t *testing.T) { } } -func createTestServer(c string, sc int) *httptest.Server { - return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(sc) - io.WriteString(w, c) +func TestRoutesCaching(t *testing.T) { + count200s := atomic.Int32{} + count304s := atomic.Int32{} + expected := eskip.MustParse(fmt.Sprintf("VALID: %v;", routeBody)) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if noneMatch := r.Header.Get("If-None-Match"); noneMatch == "test-etag" { + t.Logf("request matches etag: %s", noneMatch) + w.WriteHeader(http.StatusNotModified) + count304s.Add(1) + } else { + w.Header().Set("ETag", "test-etag") + io.WriteString(w, fmt.Sprintf("VALID: %v;", routeBody)) + count200s.Add(1) + } })) + defer server.Close() + + options := &RemoteWatchOptions{RemoteFile: server.URL, Threshold: 10, Verbose: true, FailOnStartup: true} + client, err := RemoteWatch(options) // First load done with initialization because of FailOnStartup + + require.NoError(t, err) + + defer client.(*remoteEskipFile).Close() + + r, err := client.LoadAll() + + t.Logf("uncached responses received: %d", count200s.Load()) + assert.Equal(t, int32(1), count200s.Load()) + t.Logf("cached responses received: %d", count304s.Load()) + assert.Equal(t, int32(1), count304s.Load()) + + require.NoError(t, err) + + assert.Equal(t, r, expected) + +} + +func TestRoutesCachingWrongEtag(t *testing.T) { + alternate := atomic.Int32{} + count200s := atomic.Int32{} + count304s := atomic.Int32{} + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + expectedEtag := "test-etag" + if alternate.Load()%2 == 0 { + expectedEtag = "different-etag" + } + if noneMatch := r.Header.Get("If-None-Match"); noneMatch == expectedEtag { + w.WriteHeader(http.StatusNotModified) + count304s.Add(1) + } else { + if alternate.Load()%2 == 0 { + w.Header().Set("ETag", "different-etag") + io.WriteString(w, fmt.Sprintf("different: %v;", routeBody)) + } else { + w.Header().Set("ETag", "test-etag") + io.WriteString(w, fmt.Sprintf("VALID: %v;", routeBody)) + } + count200s.Add(1) + } + alternate.Add(1) + })) + defer ts.Close() + + options := &RemoteWatchOptions{RemoteFile: ts.URL, Threshold: 10, Verbose: true, FailOnStartup: true} + client, err := RemoteWatch(options) + require.NoError(t, err) + + defer client.(*remoteEskipFile).Close() + + r, err := client.LoadAll() + + require.NoError(t, err) + + t.Logf("uncached responses received: %d", count200s.Load()) + assert.Equal(t, int32(2), count200s.Load()) + t.Logf("cached responses received: %d", count304s.Load()) + assert.Equal(t, int32(0), count304s.Load()) + + expected := eskip.MustParse(fmt.Sprintf("different: %v;", routeBody)) + + t.Logf("routes returned: %s", r[0].Id) + t.Logf("routes expected: %s", expected[0].Id) + + assert.NotEqual(t, r, expected) + + r, err = client.LoadAll() + + require.NoError(t, err) + + t.Logf("uncached responses received: %d", count200s.Load()) + assert.Equal(t, int32(3), count200s.Load()) + t.Logf("cached responses received: %d", count304s.Load()) + assert.Equal(t, int32(0), count304s.Load()) + + expected = eskip.MustParse(fmt.Sprintf("VALID: %v;", routeBody)) + + t.Logf("routes returned: %s", r[0].Id) + t.Logf("routes expected: %s", expected[0].Id) + + assert.NotEqual(t, r, expected) + }