Skip to content

Commit

Permalink
feat(container): implement generic Retry-After header handling (#5867)
Browse files Browse the repository at this point in the history
Signed-off-by: Aleksei Igrychev <[email protected]>
  • Loading branch information
alexey-igrychev committed Dec 2, 2023
1 parent 0dc4ed0 commit b2a6022
Show file tree
Hide file tree
Showing 6 changed files with 90 additions and 85 deletions.
47 changes: 9 additions & 38 deletions pkg/docker_registry/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,10 @@ package docker_registry

import (
"context"
"crypto/tls"
"errors"
"fmt"
"io"
"math/rand"
"net/http"
"regexp"
"strings"
"time"
Expand Down Expand Up @@ -115,7 +113,7 @@ func (api *api) TryGetRepoImage(ctx context.Context, reference string) (*image.I
}

func (api *api) GetRepoImage(ctx context.Context, reference string) (*image.Info, error) {
desc, _, err := api.getImageDesc(reference)
desc, _, err := api.getImageDesc(ctx, reference)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -161,7 +159,7 @@ func (api *api) getRepoImageByDesc(ctx context.Context, originalTag string, desc

for _, desc := range im.Manifests {
subref := fmt.Sprintf("%s@%s", repoImage.Repository, desc.Digest)
subdesc, _, err := api.getImageDesc(subref)
subdesc, _, err := api.getImageDesc(ctx, subref)
if err != nil {
return nil, fmt.Errorf("error getting image %s manifest: %w", subref, err)
}
Expand Down Expand Up @@ -377,7 +375,7 @@ func (api *api) MutateAndPushImage(ctx context.Context, sourceReference, destina
}
_, isDstDigest := dstRef.(name.Digest)

desc, _, err := api.getImageDesc(sourceReference)
desc, _, err := api.getImageDesc(ctx, sourceReference)
if err != nil {
return fmt.Errorf("error reading image %q: %w", sourceReference, err)
}
Expand Down Expand Up @@ -406,7 +404,7 @@ func (api *api) CopyImage(ctx context.Context, sourceReference, destinationRefer
if err != nil {
return fmt.Errorf("parsing reference %q: %w", destinationReference, err)
}
desc, _, err := api.getImageDesc(sourceReference)
desc, _, err := api.getImageDesc(ctx, sourceReference)
if err != nil {
return fmt.Errorf("unable to get image %s: %w", sourceReference, err)
}
Expand Down Expand Up @@ -472,17 +470,13 @@ func (api *api) pushImage(ctx context.Context, reference string, opts *PushImage
return nil
}

func (api *api) getImageDesc(reference string) (*remote.Descriptor, name.Reference, error) {
func (api *api) getImageDesc(ctx context.Context, reference string) (*remote.Descriptor, name.Reference, error) {
ref, err := name.ParseReference(reference, api.parseReferenceOptions()...)
if err != nil {
return nil, nil, fmt.Errorf("parsing reference %q: %w", reference, err)
}

desc, err := remote.Get(
ref,
remote.WithAuthFromKeychain(authn.DefaultKeychain),
remote.WithTransport(api.getHttpTransport()),
)
desc, err := remote.Get(ref, api.defaultRemoteOptions(ctx)...)
if err != nil {
return nil, nil, fmt.Errorf("getting %s: %w", ref, err)
}
Expand All @@ -508,33 +502,10 @@ func (api *api) defaultRemoteOptions(ctx context.Context) []remote.Option {
return []remote.Option{
remote.WithContext(ctx),
remote.WithAuthFromKeychain(authn.DefaultKeychain),
remote.WithTransport(api.getHttpTransport()),
remote.WithTransport(getHttpTransport(api.SkipTlsVerifyRegistry)),
}
}

func (api *api) getHttpTransport() (transport http.RoundTripper) {
transport = http.DefaultTransport

if api.SkipTlsVerifyRegistry {
defaultTransport := http.DefaultTransport.(*http.Transport)

newTransport := &http.Transport{
Proxy: defaultTransport.Proxy,
DialContext: defaultTransport.DialContext,
MaxIdleConns: defaultTransport.MaxIdleConns,
IdleConnTimeout: defaultTransport.IdleConnTimeout,
TLSHandshakeTimeout: defaultTransport.TLSHandshakeTimeout,
ExpectContinueTimeout: defaultTransport.ExpectContinueTimeout,
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
TLSNextProto: make(map[string]func(authority string, c *tls.Conn) http.RoundTripper),
}

transport = newTransport
}

return
}

type referenceParts struct {
registry string
repository string
Expand Down Expand Up @@ -631,15 +602,15 @@ func (api *api) writeToRemote(ctx context.Context, ref name.Reference, imageOrIn
ref, i,
remote.WithAuthFromKeychain(authn.DefaultKeychain),
remote.WithProgress(c),
remote.WithTransport(api.getHttpTransport()),
remote.WithTransport(getHttpTransport(api.SkipTlsVerifyRegistry)),
remote.WithContext(ctx),
)
case v1.ImageIndex:
go remote.WriteIndex(
ref, i,
remote.WithAuthFromKeychain(authn.DefaultKeychain),
remote.WithProgress(c),
remote.WithTransport(api.getHttpTransport()),
remote.WithTransport(getHttpTransport(api.SkipTlsVerifyRegistry)),
remote.WithContext(ctx),
)
default:
Expand Down
23 changes: 22 additions & 1 deletion pkg/docker_registry/common_api.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,19 @@ package docker_registry

import (
"context"
"crypto/tls"
"errors"
"fmt"
"io"
"io/ioutil"
"net/http"
"strings"

"github.com/google/go-containerregistry/pkg/v1/remote"
"github.com/google/go-containerregistry/pkg/v1/remote/transport"

"github.com/werf/logboek"
transport2 "github.com/werf/werf/pkg/docker_registry/transport"
)

type apiError struct {
Expand All @@ -22,6 +25,7 @@ type doRequestOptions struct {
Headers map[string]string
BasicAuth doRequestBasicAuth
AcceptedCodes []int
SkipTlsVerify bool
}

type doRequestBasicAuth struct {
Expand All @@ -44,7 +48,7 @@ func doRequest(ctx context.Context, method, url string, body io.Reader, options
}

logboek.Context(ctx).Debug().LogF("--> %s %s\n", method, url)
resp, err := http.DefaultClient.Do(req)
resp, err := getHTTPClient(options.SkipTlsVerify).Do(req)
if err != nil {
return nil, nil, err
}
Expand All @@ -67,3 +71,20 @@ func doRequest(ctx context.Context, method, url string, body io.Reader, options

return resp, respBody, nil
}

func getHTTPClient(skipTlsVerify bool) *http.Client {
return &http.Client{
Transport: getHttpTransport(skipTlsVerify),
}
}

func getHttpTransport(skipTlsVerify bool) http.RoundTripper {
t := remote.DefaultTransport.(*http.Transport).Clone()

if skipTlsVerify {
t.TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
t.TLSNextProto = make(map[string]func(authority string, c *tls.Conn) http.RoundTripper)
}

return transport.NewRetry(transport2.NewRetryAfter(t))
}
4 changes: 2 additions & 2 deletions pkg/docker_registry/generic_api.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,8 @@ func (api *genericApi) GetRepoImageConfigFile(ctx context.Context, reference str
return api.getRepoImageConfigFile(ctx, reference)
}

func (api *genericApi) getRepoImageConfigFile(_ context.Context, reference string) (*v1.ConfigFile, error) {
desc, _, err := api.commonApi.getImageDesc(reference)
func (api *genericApi) getRepoImageConfigFile(ctx context.Context, reference string) (*v1.ConfigFile, error) {
desc, _, err := api.commonApi.getImageDesc(ctx, reference)
if err != nil {
return nil, err
}
Expand Down
46 changes: 3 additions & 43 deletions pkg/docker_registry/github_packages_api.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,8 @@ import (
"context"
"encoding/json"
"fmt"
"io"
"math/rand"
"net/http"
"strconv"
"time"

"github.com/werf/logboek"
parallelConstant "github.com/werf/werf/pkg/util/parallel/constant"
)

type gitHubApi struct{}
Expand Down Expand Up @@ -63,7 +57,7 @@ type githubApiUser struct {

func (api *gitHubApi) getUser(ctx context.Context, username, token string) (githubApiUser, *http.Response, error) {
url := fmt.Sprintf("https://api.github.com/users/%s", username)
resp, respBody, err := api.doRequest(ctx, http.MethodGet, url, nil, doRequestOptions{
resp, respBody, err := doRequest(ctx, http.MethodGet, url, nil, doRequestOptions{
Headers: map[string]string{
"Accept": "application/vnd.github.v3+json",
"Authorization": fmt.Sprintf("Bearer %s", token),
Expand Down Expand Up @@ -141,7 +135,7 @@ type githubApiVersion struct {
func (api *gitHubApi) getContainerPackageVersionListInBatches(ctx context.Context, url, token string, f func([]githubApiVersion) error) (*http.Response, error) {
for page := 1; true; page++ {
pageUrl := url + fmt.Sprintf("?page=%d&per_page=100", page)
resp, respBody, err := api.doRequest(ctx, http.MethodGet, pageUrl, nil, doRequestOptions{
resp, respBody, err := doRequest(ctx, http.MethodGet, pageUrl, nil, doRequestOptions{
Headers: map[string]string{
"Accept": "application/vnd.github.v3+json",
"Authorization": fmt.Sprintf("Bearer %s", token),
Expand Down Expand Up @@ -170,7 +164,7 @@ func (api *gitHubApi) getContainerPackageVersionListInBatches(ctx context.Contex
}

func (api *gitHubApi) deleteContainerPackage(ctx context.Context, url, token string) (*http.Response, error) {
resp, _, err := api.doRequest(ctx, http.MethodDelete, url, nil, doRequestOptions{
resp, _, err := doRequest(ctx, http.MethodDelete, url, nil, doRequestOptions{
Headers: map[string]string{
"Accept": "application/vnd.github.v3+json",
"Authorization": fmt.Sprintf("Bearer %s", token),
Expand All @@ -183,37 +177,3 @@ func (api *gitHubApi) deleteContainerPackage(ctx context.Context, url, token str

return nil, nil
}

func (api *gitHubApi) doRequest(ctx context.Context, method, url string, body io.Reader, options doRequestOptions) (*http.Response, []byte, error) {
resp, respBody, err := doRequest(ctx, method, url, body, options)
if err != nil {
if resp != nil && resp.Header.Get("Retry-After") != "" {
secondsString := resp.Header.Get("Retry-After")
seconds, err := strconv.Atoi(secondsString)
if err == nil {
sleepSeconds := seconds + rand.Intn(15) + 5
workerId := ctx.Value(parallelConstant.CtxBackgroundTaskIDKey)
if workerId != nil {
logboek.Context(ctx).Warn().LogF(
"WARNING: Rate limit error occurred. Waiting for %d before retrying request... (worker %d).\nThe --parallel ($WERF_PARALLEL) and --parallel-tasks-limit ($WERF_PARALLEL_TASKS_LIMIT) options can be used to regulate parallel tasks.\n",
sleepSeconds,
workerId.(int),
)
logboek.Context(ctx).Warn().LogLn()
} else {
logboek.Context(ctx).Warn().LogF(
"WARNING: Rate limit error occurred. Waiting for %d before retrying request...\n",
sleepSeconds,
)
}

time.Sleep(time.Second * time.Duration(sleepSeconds))
return api.doRequest(ctx, method, url, body, options)
}
}

return resp, respBody, err
}

return resp, respBody, nil
}
2 changes: 1 addition & 1 deletion pkg/docker_registry/gitlab_registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ func (r *gitLabRegistry) customDeleteRepoImage(endpointFormat, reference string,
}

scope := scopeFunc(ref)
tr, err := transport.New(ref.Context().Registry, auth, r.api.getHttpTransport(), scope)
tr, err := transport.New(ref.Context().Registry, auth, getHttpTransport(false), scope)
if err != nil {
return err
}
Expand Down
53 changes: 53 additions & 0 deletions pkg/docker_registry/transport/retry_after.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
package transport

import (
"math/rand"
"net/http"
"strconv"
"time"

"github.com/werf/logboek"
parallelConstant "github.com/werf/werf/pkg/util/parallel/constant"
)

type RetryAfter struct {
underlying http.RoundTripper
}

func NewRetryAfter(underlying http.RoundTripper) http.RoundTripper {
return &RetryAfter{underlying: underlying}
}

func (t *RetryAfter) RoundTrip(req *http.Request) (*http.Response, error) {
resp, err := t.underlying.RoundTrip(req)
if err != nil {
return nil, err
}

if retryAfterHeader := resp.Header.Get("Retry-After"); retryAfterHeader != "" {
if seconds, err := strconv.Atoi(retryAfterHeader); err == nil {
sleepSeconds := rand.Intn(15) + 5

ctx := req.Context()
workerId := ctx.Value(parallelConstant.CtxBackgroundTaskIDKey)
if workerId != nil {
logboek.Context(ctx).Warn().LogF(
"WARNING: Rate limit error occurred. Waiting for %d before retrying request... (worker %d).\nThe --parallel ($WERF_PARALLEL) and --parallel-tasks-limit ($WERF_PARALLEL_TASKS_LIMIT) options can be used to regulate parallel tasks.\n",
sleepSeconds,
workerId.(int),
)
logboek.Context(ctx).Warn().LogLn()
} else {
logboek.Context(ctx).Warn().LogF(
"WARNING: Rate limit error occurred. Waiting for %d before retrying request...\n",
sleepSeconds,
)
}

time.Sleep(time.Second * time.Duration(seconds))
return t.RoundTrip(req)
}
}

return resp, nil
}

0 comments on commit b2a6022

Please sign in to comment.