From a834d901303037597d1b46de533b1ac0227bf4a0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sandor=20Sz=C3=BCcs?= Date: Thu, 14 Mar 2024 00:18:23 +0100 Subject: [PATCH] feature: retry() filter feature: net.Client.Retry() MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Sandor Szücs --- filters/builtin/builtin.go | 2 + filters/filters.go | 1 + filters/retry/retry.go | 18 +++ filters/retry/retry_test.go | 95 ++++++++++++++++ io/copy_stream.go | 48 ++++++++ io/copy_stream_test.go | 41 +++++++ net/httpclient.go | 62 +++-------- net/httpclient_test.go | 213 +++++++++++++++++++++++++++++++++++- proxy/proxy.go | 43 ++++++++ 9 files changed, 473 insertions(+), 50 deletions(-) create mode 100644 filters/retry/retry.go create mode 100644 filters/retry/retry_test.go create mode 100644 io/copy_stream.go create mode 100644 io/copy_stream_test.go diff --git a/filters/builtin/builtin.go b/filters/builtin/builtin.go index d5c9e34f24..48820049a8 100644 --- a/filters/builtin/builtin.go +++ b/filters/builtin/builtin.go @@ -15,6 +15,7 @@ import ( "github.com/zalando/skipper/filters/fadein" "github.com/zalando/skipper/filters/flowid" logfilter "github.com/zalando/skipper/filters/log" + "github.com/zalando/skipper/filters/retry" "github.com/zalando/skipper/filters/rfc" "github.com/zalando/skipper/filters/scheduler" "github.com/zalando/skipper/filters/sed" @@ -229,6 +230,7 @@ func Filters() []filters.Spec { fadein.NewEndpointCreated(), consistenthash.NewConsistentHashKey(), consistenthash.NewConsistentHashBalanceFactor(), + retry.NewRetry(), tls.New(), } } diff --git a/filters/filters.go b/filters/filters.go index a43c4b19e4..bf7b2dd593 100644 --- a/filters/filters.go +++ b/filters/filters.go @@ -333,6 +333,7 @@ const ( FifoWithBodyName = "fifoWithBody" LifoName = "lifo" LifoGroupName = "lifoGroup" + RetryName = "retry" RfcPathName = "rfcPath" RfcHostName = "rfcHost" BearerInjectorName = "bearerinjector" diff --git 
a/filters/retry/retry.go b/filters/retry/retry.go new file mode 100644 index 0000000000..3890f87df4 --- /dev/null +++ b/filters/retry/retry.go @@ -0,0 +1,18 @@ +package retry + +import ( + "github.com/zalando/skipper/filters" +) + +type retry struct{} + +// NewRetry creates a filter specification for the retry() filter +func NewRetry() filters.Spec { return retry{} } + +func (retry) Name() string { return filters.RetryName } +func (retry) CreateFilter([]interface{}) (filters.Filter, error) { return retry{}, nil } +func (retry) Response(filters.FilterContext) {} + +func (retry) Request(ctx filters.FilterContext) { + ctx.StateBag()[filters.RetryName] = struct{}{} +} diff --git a/filters/retry/retry_test.go b/filters/retry/retry_test.go new file mode 100644 index 0000000000..81f97c5bbb --- /dev/null +++ b/filters/retry/retry_test.go @@ -0,0 +1,95 @@ +package retry + +import ( + "bytes" + "io" + "net/http" + "net/http/httptest" + "testing" + + "github.com/AlexanderYastrebov/noleak" + "github.com/zalando/skipper/eskip" + "github.com/zalando/skipper/filters" + "github.com/zalando/skipper/proxy/proxytest" +) + +func TestRetry(t *testing.T) { + for _, tt := range []struct { + name string + method string + body string + }{ + { + name: "test GET", + method: "GET", + }, + { + name: "test POST", + method: "POST", + body: "hello POST", + }, + { + name: "test PATCH", + method: "PATCH", + body: "hello PATCH", + }, + { + name: "test PUT", + method: "PUT", + body: "hello PUT", + }} { + t.Run(tt.name, func(t *testing.T) { + i := 0 + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if i == 0 { + i++ + w.WriteHeader(http.StatusBadGateway) + return + } + + got, err := io.ReadAll(r.Body) + if err != nil { + t.Fatalf("got no data") + } + s := string(got) + if tt.body != s { + t.Fatalf("Failed to get the right data want: %q, got: %q", tt.body, s) + } + + w.WriteHeader(http.StatusOK) + })) + defer backend.Close() + + noleak.Check(t) + + fr := 
make(filters.Registry) + retry := NewRetry() + fr.Register(retry) + r := &eskip.Route{ + Filters: []*eskip.Filter{ + {Name: retry.Name()}, + }, + Backend: backend.URL, + } + + proxy := proxytest.New(fr, r) + defer proxy.Close() + + buf := bytes.NewBufferString(tt.body) + req, err := http.NewRequest(tt.method, proxy.URL, buf) + if err != nil { + t.Fatal(err) + } + + rsp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("Failed to execute retry: %v", err) + } + + if rsp.StatusCode != http.StatusOK { + t.Fatalf("unexpected status code: %s", rsp.Status) + } + rsp.Body.Close() + }) + } +} diff --git a/io/copy_stream.go b/io/copy_stream.go new file mode 100644 index 0000000000..d4b0e994a8 --- /dev/null +++ b/io/copy_stream.go @@ -0,0 +1,48 @@ +package io + +import ( + "bytes" + "io" +) + +type bodyBuffer struct{ *bytes.Buffer } + +func (buf *bodyBuffer) Close() error { + return nil +} + +type CopyBodyStream struct { + left int + buf *bodyBuffer + input io.ReadCloser +} + +func NewCopyBodyStream(left int, buf *bytes.Buffer, rc io.ReadCloser) *CopyBodyStream { + return &CopyBodyStream{ + left: left, + buf: &bodyBuffer{Buffer: buf}, + input: rc, + } +} + +func (cb *CopyBodyStream) Len() int { + return cb.buf.Len() +} + +func (cb *CopyBodyStream) Read(p []byte) (n int, err error) { + n, err = cb.input.Read(p) + if cb.left > 0 && n > 0 { + m := min(n, cb.left) + cb.buf.Write(p[:m]) + cb.left -= m + } + return n, err +} + +func (cb *CopyBodyStream) Close() error { + return cb.input.Close() +} + +func (cb *CopyBodyStream) GetBody() io.ReadCloser { + return cb.buf +} diff --git a/io/copy_stream_test.go b/io/copy_stream_test.go new file mode 100644 index 0000000000..8ea2f0ded2 --- /dev/null +++ b/io/copy_stream_test.go @@ -0,0 +1,41 @@ +package io + +import ( + "bytes" + "io" + "testing" +) + +type tbuf struct{ *bytes.Buffer } + +func (tb *tbuf) Read(p []byte) (int, error) { + return tb.Buffer.Read(p) +} +func (tb *tbuf) Close() error { + return nil +} + +func 
TestCopyBodyStream(t *testing.T) { + s := "content" + bbuf := &tbuf{bytes.NewBufferString(s)} + cbs := NewCopyBodyStream(bbuf.Len(), &bytes.Buffer{}, bbuf) + + buf := make([]byte, len(s)) + cbs.Read(buf) + + if cbs.Len() != len(buf) { + t.Fatalf("Failed to have the same buf buffer size want: %d, got: %d", cbs.Len(), len(buf)) + } + + got, err := io.ReadAll(cbs.GetBody()) + if err != nil { + t.Fatalf("Failed to read: %v", err) + } + if gotStr := string(got); s != gotStr { + t.Fatalf("Failed to get the right content: %s != %s", s, gotStr) + } + + if err = cbs.Close(); err != nil { + t.Fatal(err) + } +} diff --git a/net/httpclient.go b/net/httpclient.go index 833fab3898..e52f00966d 100644 --- a/net/httpclient.go +++ b/net/httpclient.go @@ -3,6 +3,7 @@ package net import ( "bytes" "crypto/tls" + "errors" "fmt" "io" "net/http" @@ -15,6 +16,7 @@ import ( "github.com/opentracing/opentracing-go" "github.com/opentracing/opentracing-go/ext" + skpio "github.com/zalando/skipper/io" "github.com/zalando/skipper/logging" "github.com/zalando/skipper/secrets" ) @@ -24,43 +26,7 @@ const ( defaultRefreshInterval = 5 * time.Minute ) -type mybuf struct{ *bytes.Buffer } - -func (buf *mybuf) Close() error { - return nil -} - -type copyBodyStream struct { - left int - buf *mybuf - input io.ReadCloser -} - -func newCopyBodyStream(left int, buf *bytes.Buffer, rc io.ReadCloser) *copyBodyStream { - return &copyBodyStream{ - left: left, - buf: &mybuf{Buffer: buf}, - input: rc, - } -} - -func (cb *copyBodyStream) Read(p []byte) (n int, err error) { - n, err = cb.input.Read(p) - if cb.left > 0 && n > 0 { - m := min(n, cb.left) - cb.buf.Write(p[:m]) - cb.left -= m - } - return n, err -} - -func (cb *copyBodyStream) Close() error { - return cb.input.Close() -} - -func (cb *copyBodyStream) GetBody() io.ReadCloser { - return cb.buf -} +var errRequestNotFound = errors.New("request not found") // Client adds additional features like Bearer token injection, and // opentracing to the wrapped http.Client 
with the same interface as @@ -166,8 +132,8 @@ func (c *Client) Do(req *http.Request) (*http.Response, error) { req.Header.Set("Authorization", "Bearer "+string(b)) } } - if req.Body != nil && req.Body != http.NoBody { - retryBuffer := newCopyBodyStream(int(req.ContentLength), &bytes.Buffer{}, req.Body) + if req.Body != nil && req.Body != http.NoBody && req.ContentLength > 0 { + retryBuffer := skpio.NewCopyBodyStream(int(req.ContentLength), &bytes.Buffer{}, req.Body) c.retryBuffers.Store(req, retryBuffer) req.Body = retryBuffer } @@ -179,20 +145,20 @@ func (c *Client) Retry(req *http.Request) (*http.Response, error) { return c.Do(req) } - if rc, err := req.GetBody(); err == nil { - println("req.GetBody() case") - c.retryBuffers.Delete(req) - req.Body = rc - return c.Do(req) - } + // Next line panics on TestClientRetryBodyHalfReader + // if rc, err := req.GetBody(); err == nil { + // c.retryBuffers.Delete(req) + // req.Body = rc + // return c.Do(req) + // } + // return nil, fmt.Errorf("failed to retry") - println("our own retry buffer impl") buf, ok := c.retryBuffers.Load(req) if !ok { - return nil, fmt.Errorf("no retry possible, request not found: %s %s", req.Method, req.URL) + return nil, fmt.Errorf("no retry possible, %w: %s %s", errRequestNotFound, req.Method, req.URL) } - retryBuffer, ok := buf.(*copyBodyStream) + retryBuffer, ok := buf.(*skpio.CopyBodyStream) if !ok { return nil, fmt.Errorf("no retry possible, no retry buffer for request: %s %s", req.Method, req.URL) } diff --git a/net/httpclient_test.go b/net/httpclient_test.go index f744e9984a..176e6592ec 100644 --- a/net/httpclient_test.go +++ b/net/httpclient_test.go @@ -2,6 +2,7 @@ package net import ( "bytes" + "errors" "io" "net/http" "net/http/httptest" @@ -9,6 +10,7 @@ import ( "os" "path/filepath" "testing" + "testing/iotest" "time" "github.com/AlexanderYastrebov/noleak" @@ -327,7 +329,7 @@ func TestClientClosesIdleConnections(t *testing.T) { rsp.Body.Close() } -func TestTestClientRetry(t *testing.T) 
{ +func TestClientRetry(t *testing.T) { for _, tt := range []struct { name string method string @@ -358,6 +360,7 @@ func TestTestClientRetry(t *testing.T) { if i == 0 { i++ w.WriteHeader(http.StatusBadGateway) + return } got, err := io.ReadAll(r.Body) @@ -404,7 +407,7 @@ func TestTestClientRetry(t *testing.T) { } } -func TestTestClientRetryConcurrentRequests(t *testing.T) { +func TestClientRetryConcurrentRequests(t *testing.T) { for _, tt := range []struct { name string method string @@ -500,3 +503,209 @@ func TestTestClientRetryConcurrentRequests(t *testing.T) { }) } } + +func TestClientRetryFailConcurrentRequests(t *testing.T) { + for _, tt := range []struct { + name string + method string + body string + }{ + { + name: "test GET", + method: "GET", + }, + { + name: "test POST", + method: "POST", + body: "hello POST", + }, + { + name: "test PATCH", + method: "PATCH", + body: "hello PATCH", + }, + { + name: "test PUT", + method: "PUT", + body: "hello PUT", + }} { + t.Run(tt.name, func(t *testing.T) { + i := 0 + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/ignore" { + w.WriteHeader(http.StatusOK) + return + } + + if i < 3 { + i++ + io.ReadAll(r.Body) + w.WriteHeader(http.StatusBadGateway) + return + } + + got, err := io.ReadAll(r.Body) + if err != nil { + t.Fatalf("got no data") + } + s := string(got) + if tt.body != s { + t.Fatalf("Failed to get the right data want: %q, got: %q", tt.body, s) + } + + w.WriteHeader(http.StatusOK) + })) + defer backend.Close() + + noleak.Check(t) + + cli := NewClient(Options{}) + defer cli.Close() + + quit := make(chan struct{}) + go func() { + for { + select { + case <-quit: + return + default: + } + cli.Get(backend.URL + "/ignore") + } + }() + + buf := bytes.NewBufferString(tt.body) + req, err := http.NewRequest(tt.method, backend.URL, buf) + if err != nil { + t.Fatal(err) + } + rsp, err := cli.Do(req) + if err != nil { + t.Fatal(err) + } + if rsp.StatusCode != 
http.StatusBadGateway { + t.Fatalf("unexpected status code: %s", rsp.Status) + } + + for i := 0; i < 2; i++ { + rsp, err = cli.Retry(req) + if err != nil { + t.Fatalf("Failed to execute retry: %v", err) + } + if rsp.StatusCode != http.StatusBadGateway { + t.Fatalf("unexpected status code: %s", rsp.Status) + } + } + + rsp, err = cli.Retry(req) + if err != nil { + t.Fatalf("Failed to execute retry: %v", err) + } + if rsp.StatusCode != http.StatusOK { + t.Fatalf("unexpected status code: %s", rsp.Status) + } + rsp.Body.Close() + + close(quit) + }) + } +} + +type halfReader struct { + r io.Reader +} + +func newHalfReader(r io.Reader) *halfReader { + return &halfReader{ + r: iotest.HalfReader(r), + } +} + +func (hr *halfReader) Read(p []byte) (int, error) { + return hr.r.Read(p) +} + +func TestClientRetryBodyHalfReader(t *testing.T) { + for _, tt := range []struct { + name string + method string + body string + }{ + { + name: "test GET", + method: "GET", + }, + { + name: "test POST", + method: "POST", + body: "hello POST", + }, + { + name: "test PATCH", + method: "PATCH", + body: "hello PATCH", + }, + { + name: "test PUT", + method: "PUT", + body: "hello PUT", + }} { + t.Run(tt.name, func(t *testing.T) { + i := 0 + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if i == 0 { + i++ + w.WriteHeader(http.StatusBadGateway) + return + } + + got, err := io.ReadAll(r.Body) + if err != nil { + t.Fatalf("got no data") + } + + s := string(got) + if len(s) != 0 { + t.Fatalf("Failed to get the right data got: %q", s) + } + + w.WriteHeader(http.StatusOK) + })) + defer backend.Close() + + noleak.Check(t) + + cli := NewClient(Options{}) + defer cli.Close() + + b := bytes.NewBufferString(tt.body) + buf := newHalfReader(b) + + req, err := http.NewRequest(tt.method, backend.URL, buf) + if err != nil { + t.Fatal(err) + } + rsp, err := cli.Do(req) + if err != nil { + t.Fatal(err) + } + if rsp.StatusCode != http.StatusBadGateway { + 
t.Fatalf("unexpected status code: %s", rsp.Status) + } + + rsp, err = cli.Retry(req) + if err != nil { + if !errors.Is(err, errRequestNotFound) { + t.Fatalf("Failed to execute retry: %v", err) + } else { + return + } + } + + if rsp.StatusCode != http.StatusOK { + t.Fatalf("unexpected status code: %s", rsp.Status) + } + rsp.Body.Close() + }) + } +} diff --git a/proxy/proxy.go b/proxy/proxy.go index d2156d8f55..98db383e56 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -1214,8 +1214,16 @@ func (p *Proxy) do(ctx *context, parentSpan ot.Span) (err error) { return errCircuitBreakerOpen } + var retryBuffer *skpio.CopyBodyStream + retryConfig, ok := ctx.StateBag()[filters.RetryName] + if ok { + retryBuffer = skpio.NewCopyBodyStream(int(ctx.Request().ContentLength), &bytes.Buffer{}, ctx.Request().Body) + ctx.Request().Body = retryBuffer + } + backendContext := ctx.request.Context() if timeout, ok := ctx.StateBag()[filters.BackendTimeout]; ok { + defer ctx.cancelBackendContext() backendContext, ctx.cancelBackendContext = stdlibcontext.WithTimeout(backendContext, timeout.(time.Duration)) } @@ -1254,6 +1262,24 @@ func (p *Proxy) do(ctx *context, parentSpan ot.Span) (err error) { p.applyFiltersOnError(ctx, processedFilters) return perr2 } + + } else if retryConfig != nil { + ctx.Logger().Infof("execute retry") + perr = nil + var perr2 *proxyError + + ctx.request.Body = retryBuffer.GetBody() + rsp, perr2 = p.makeBackendRequest(ctx, backendContext) + if perr2 != nil { + ctx.Logger().Errorf("Failed to retry backend request by filter: %v", perr2) + if perr2.code >= http.StatusInternalServerError { + p.metrics.MeasureBackend5xx(backendStart) + } + p.makeErrorResponse(ctx, perr2) + p.applyFiltersOnError(ctx, processedFilters) + return perr2 + } + } else { p.makeErrorResponse(ctx, perr) p.applyFiltersOnError(ctx, processedFilters) @@ -1263,6 +1289,23 @@ func (p *Proxy) do(ctx *context, parentSpan ot.Span) (err error) { if rsp.StatusCode >= http.StatusInternalServerError { 
p.metrics.MeasureBackend5xx(backendStart) + + if retryConfig != nil { + ctx.Logger().Infof("execute retry filter") + perr = nil + + ctx.request.Body = retryBuffer.GetBody() + rsp, perr = p.makeBackendRequest(ctx, backendContext) + if perr != nil { + ctx.Logger().Errorf("Failed to retry backend request by filter: %v", perr) + if perr.code >= http.StatusInternalServerError { + p.metrics.MeasureBackend5xx(backendStart) + } + p.makeErrorResponse(ctx, perr) + p.applyFiltersOnError(ctx, processedFilters) + return perr + } + } } if done != nil {