Skip to content

Commit 1e624f9

Browse files
Add retries for transient error code
Signed-off-by: Tien Nguyen <[email protected]>
1 parent b69bf6c commit 1e624f9

File tree

5 files changed

+226
-38
lines changed

5 files changed

+226
-38
lines changed

.generator/config.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ files:
2222
api_user_test.go: {}
2323
cache_test.go: {}
2424
cache.go: {}
25+
client_test.go: {}
2526
configuration_test.go: {}
2627
gocache.go: {}
2728
main_test.go: {}

.generator/templates/client.mustache

Lines changed: 48 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ import (
2626
"path/filepath"
2727
"reflect"
2828
"regexp"
29+
"slices"
2930
"strconv"
3031
"strings"
3132
"time"
@@ -1204,13 +1205,16 @@ func (c *APIClient) doWithRetries(ctx context.Context, req *http.Request) (*http
12041205
// this is error is considered to be permanent and should not be retried
12051206
return backoff.Permanent(err)
12061207
}
1207-
if !tooManyRequests(resp) {
1208+
if !shouldRetryRequests(resp) {
12081209
return nil
12091210
}
1210-
if err = tryDrainBody(resp.Body); err != nil {
1211-
return err
1211+
// tryDrainBody except the last one
1212+
if bOff.maxRetries != bOff.retryCount {
1213+
if err = tryDrainBody(resp.Body); err != nil {
1214+
return err
1215+
}
12121216
}
1213-
backoffDuration, err := Get429BackoffTime(resp)
1217+
backoffDuration, err := GetBackoffTime(resp, c.cfg)
12141218
if err != nil {
12151219
return err
12161220
}
@@ -1221,7 +1225,7 @@ func (c *APIClient) doWithRetries(ctx context.Context, req *http.Request) (*http
12211225
bOff.retryCount++
12221226
req.Header.Add("X-Okta-Retry-For", resp.Header.Get("X-Okta-Request-Id"))
12231227
req.Header.Add("X-Okta-Retry-Count", fmt.Sprint(bOff.retryCount))
1224-
return errors.New("too many requests")
1228+
return fmt.Errorf("unexpected status code %v, giving up after %v retries", resp.StatusCode, bOff.maxRetries)
12251229
}
12261230
err = backoff.Retry(operation, bOff)
12271231
return resp, err
@@ -1410,28 +1414,53 @@ func (o *oktaBackoff) Context() context.Context {
14101414
return o.ctx
14111415
}
14121416

1413-
func tooManyRequests(resp *http.Response) bool {
1414-
return resp != nil && resp.StatusCode == http.StatusTooManyRequests
1417+
var retryStatusCodes = []int{
1418+
http.StatusTooManyRequests,
1419+
http.StatusInternalServerError,
1420+
http.StatusBadGateway,
1421+
http.StatusServiceUnavailable,
1422+
http.StatusGatewayTimeout,
14151423
}
14161424

1425+
var defaultBackOffTimeInSeconds int64 = 30
1426+
1427+
func shouldRetryRequests(resp *http.Response) bool {
1428+
if resp == nil {
1429+
return false
1430+
}
1431+
return slices.Contains(retryStatusCodes, resp.StatusCode)
1432+
}
14171433
func tryDrainBody(body io.ReadCloser) error {
14181434
defer body.Close()
14191435
_, err := io.Copy(ioutil.Discard, io.LimitReader(body, 4096))
14201436
return err
14211437
}
14221438

1423-
func Get429BackoffTime(resp *http.Response) (int64, error) {
1424-
requestDate, err := time.Parse("Mon, 02 Jan 2006 15:04:05 GMT", resp.Header.Get("Date"))
1425-
if err != nil {
1426-
// this is error is considered to be permanent and should not be retried
1427-
return 0, backoff.Permanent(fmt.Errorf("date header is missing or invalid: %w", err))
1428-
}
1429-
rateLimitReset, err := strconv.Atoi(resp.Header.Get("X-Rate-Limit-Reset"))
1430-
if err != nil {
1431-
// this is error is considered to be permanent and should not be retried
1432-
return 0, backoff.Permanent(fmt.Errorf("X-Rate-Limit-Reset header is missing or invalid: %w", err))
1439+
func GetBackoffTime(resp *http.Response, config *Configuration) (int64, error) {
1440+
var backoffTimeInSeconds int64
1441+
var backoffErr error
1442+
switch resp.StatusCode {
1443+
case http.StatusTooManyRequests:
1444+
requestDate, err := time.Parse("Mon, 02 Jan 2006 15:04:05 GMT", resp.Header.Get("Date"))
1445+
if err != nil {
1446+
// this is error is considered to be permanent and should not be retried
1447+
backoffTimeInSeconds = 0
1448+
backoffErr = backoff.Permanent(fmt.Errorf("date header is missing or invalid: %w", err))
1449+
}
1450+
rateLimitReset, err := strconv.Atoi(resp.Header.Get("X-Rate-Limit-Reset"))
1451+
if err != nil {
1452+
// this is error is considered to be permanent and should not be retried
1453+
backoffTimeInSeconds = 0
1454+
backoffErr = backoff.Permanent(fmt.Errorf("X-Rate-Limit-Reset header is missing or invalid: %w", err))
1455+
}
1456+
backoffTimeInSeconds = int64(rateLimitReset) - requestDate.Unix() + 1
1457+
default:
1458+
backoffTimeInSeconds = config.Okta.Client.RateLimit.MaxBackoff
1459+
if backoffTimeInSeconds == 0 {
1460+
backoffTimeInSeconds = defaultBackOffTimeInSeconds
1461+
}
14331462
}
1434-
return int64(rateLimitReset) - requestDate.Unix() + 1, nil
1463+
return backoffTimeInSeconds, backoffErr
14351464
}
14361465

14371466
type ClientAssertionClaims struct {
@@ -1543,7 +1572,7 @@ func (c *APIClient) parseLimitHeaders(resp *http.Response) (*RateLimit, error) {
15431572
if err != nil {
15441573
return nil, err
15451574
}
1546-
reset, err := Get429BackoffTime(resp)
1575+
reset, err := GetBackoffTime(resp, c.cfg)
15471576
if err != nil {
15481577
return nil, err
15491578
}
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
package okta
2+
3+
import (
4+
"io"
5+
"net/http"
6+
"strings"
7+
"testing"
8+
"time"
9+
10+
"github.com/stretchr/testify/require"
11+
)
12+
13+
type mockRoundTripper struct {
14+
Responses []http.Response
15+
CallCount int
16+
}
17+
18+
func (m *mockRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
19+
// simulate responses for each call
20+
if m.CallCount < len(m.Responses) {
21+
resp := m.Responses[m.CallCount]
22+
m.CallCount++
23+
return &resp, nil
24+
}
25+
return &http.Response{StatusCode: http.StatusInternalServerError}, nil
26+
}
27+
28+
func TestCreateRetryableHTTPClient_RetryOn500(t *testing.T) {
29+
bodyString := "this is a test body"
30+
body := io.NopCloser(strings.NewReader(bodyString))
31+
mockResponses := []http.Response{
32+
{StatusCode: 500}, // should retry
33+
{StatusCode: 500}, // should retry
34+
{StatusCode: 500}, // should retry
35+
{StatusCode: 500}, // should retry
36+
{StatusCode: 500, Body: body}, // should retry
37+
}
38+
39+
mockTransport := &mockRoundTripper{Responses: mockResponses}
40+
41+
// base client with the mock transport
42+
baseClient := &http.Client{
43+
Transport: mockTransport,
44+
Timeout: 10 * time.Second,
45+
}
46+
47+
configuration, err := NewConfiguration(WithAuthorizationMode("Bearer"), WithHttpClientPtr(baseClient), WithRateLimitMaxBackOff(1), WithRateLimitMaxRetries(4))
48+
require.NoError(t, err, "Creating a new config should not error")
49+
client := NewAPIClient(configuration)
50+
51+
req, err := http.NewRequest("GET", "http://foo.com", nil)
52+
require.NoError(t, err)
53+
54+
resp, err := client.doWithRetries(client.cfg.Context, req)
55+
56+
require.Error(t, err)
57+
require.Equal(t, "unexpected status code 500, giving up after 4 retries", err.Error())
58+
require.True(t, resp != nil)
59+
require.Equal(t, resp.StatusCode, 500)
60+
message, err := io.ReadAll(resp.Body)
61+
require.NoError(t, err)
62+
require.Equal(t, bodyString, string(message))
63+
require.Equal(t, 5, mockTransport.CallCount)
64+
}

okta/client.go

Lines changed: 49 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ import (
4848
"path/filepath"
4949
"reflect"
5050
"regexp"
51+
"slices"
5152
"strconv"
5253
"strings"
5354
"sync"
@@ -1392,13 +1393,16 @@ func (c *APIClient) doWithRetries(ctx context.Context, req *http.Request) (*http
13921393
// this is error is considered to be permanent and should not be retried
13931394
return backoff.Permanent(err)
13941395
}
1395-
if !tooManyRequests(resp) {
1396+
if !shouldRetryRequests(resp) {
13961397
return nil
13971398
}
1398-
if err = tryDrainBody(resp.Body); err != nil {
1399-
return err
1399+
// tryDrainBody except the last one
1400+
if bOff.maxRetries != bOff.retryCount {
1401+
if err = tryDrainBody(resp.Body); err != nil {
1402+
return err
1403+
}
14001404
}
1401-
backoffDuration, err := Get429BackoffTime(resp)
1405+
backoffDuration, err := GetBackoffTime(resp, c.cfg)
14021406
if err != nil {
14031407
return err
14041408
}
@@ -1409,7 +1413,7 @@ func (c *APIClient) doWithRetries(ctx context.Context, req *http.Request) (*http
14091413
bOff.retryCount++
14101414
req.Header.Add("X-Okta-Retry-For", resp.Header.Get("X-Okta-Request-Id"))
14111415
req.Header.Add("X-Okta-Retry-Count", fmt.Sprint(bOff.retryCount))
1412-
return errors.New("too many requests")
1416+
return fmt.Errorf("unexpected status code %v, giving up after %v retries", resp.StatusCode, bOff.maxRetries)
14131417
}
14141418
err = backoff.Retry(operation, bOff)
14151419
return resp, err
@@ -1598,8 +1602,21 @@ func (o *oktaBackoff) Context() context.Context {
15981602
return o.ctx
15991603
}
16001604

1601-
func tooManyRequests(resp *http.Response) bool {
1602-
return resp != nil && resp.StatusCode == http.StatusTooManyRequests
1605+
var retryStatusCodes = []int{
1606+
http.StatusTooManyRequests,
1607+
http.StatusInternalServerError,
1608+
http.StatusBadGateway,
1609+
http.StatusServiceUnavailable,
1610+
http.StatusGatewayTimeout,
1611+
}
1612+
1613+
var defaultBackOffTimeInSeconds int64 = 30
1614+
1615+
func shouldRetryRequests(resp *http.Response) bool {
1616+
if resp == nil {
1617+
return false
1618+
}
1619+
return slices.Contains(retryStatusCodes, resp.StatusCode)
16031620
}
16041621

16051622
func tryDrainBody(body io.ReadCloser) error {
@@ -1608,18 +1625,31 @@ func tryDrainBody(body io.ReadCloser) error {
16081625
return err
16091626
}
16101627

1611-
func Get429BackoffTime(resp *http.Response) (int64, error) {
1612-
requestDate, err := time.Parse("Mon, 02 Jan 2006 15:04:05 GMT", resp.Header.Get("Date"))
1613-
if err != nil {
1614-
// this is error is considered to be permanent and should not be retried
1615-
return 0, backoff.Permanent(fmt.Errorf("date header is missing or invalid: %w", err))
1616-
}
1617-
rateLimitReset, err := strconv.Atoi(resp.Header.Get("X-Rate-Limit-Reset"))
1618-
if err != nil {
1619-
// this is error is considered to be permanent and should not be retried
1620-
return 0, backoff.Permanent(fmt.Errorf("X-Rate-Limit-Reset header is missing or invalid: %w", err))
1628+
func GetBackoffTime(resp *http.Response, config *Configuration) (int64, error) {
1629+
var backoffTimeInSeconds int64
1630+
var backoffErr error
1631+
switch resp.StatusCode {
1632+
case http.StatusTooManyRequests:
1633+
requestDate, err := time.Parse("Mon, 02 Jan 2006 15:04:05 GMT", resp.Header.Get("Date"))
1634+
if err != nil {
1635+
// this is error is considered to be permanent and should not be retried
1636+
backoffTimeInSeconds = 0
1637+
backoffErr = backoff.Permanent(fmt.Errorf("date header is missing or invalid: %w", err))
1638+
}
1639+
rateLimitReset, err := strconv.Atoi(resp.Header.Get("X-Rate-Limit-Reset"))
1640+
if err != nil {
1641+
// this is error is considered to be permanent and should not be retried
1642+
backoffTimeInSeconds = 0
1643+
backoffErr = backoff.Permanent(fmt.Errorf("X-Rate-Limit-Reset header is missing or invalid: %w", err))
1644+
}
1645+
backoffTimeInSeconds = int64(rateLimitReset) - requestDate.Unix() + 1
1646+
default:
1647+
backoffTimeInSeconds = config.Okta.Client.RateLimit.MaxBackoff
1648+
if backoffTimeInSeconds == 0 {
1649+
backoffTimeInSeconds = defaultBackOffTimeInSeconds
1650+
}
16211651
}
1622-
return int64(rateLimitReset) - requestDate.Unix() + 1, nil
1652+
return backoffTimeInSeconds, backoffErr
16231653
}
16241654

16251655
type ClientAssertionClaims struct {
@@ -1731,7 +1761,7 @@ func (c *APIClient) parseLimitHeaders(resp *http.Response) (*RateLimit, error) {
17311761
if err != nil {
17321762
return nil, err
17331763
}
1734-
reset, err := Get429BackoffTime(resp)
1764+
reset, err := GetBackoffTime(resp, c.cfg)
17351765
if err != nil {
17361766
return nil, err
17371767
}

okta/client_test.go

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
package okta
2+
3+
import (
4+
"io"
5+
"net/http"
6+
"strings"
7+
"testing"
8+
"time"
9+
10+
"github.com/stretchr/testify/require"
11+
)
12+
13+
type mockRoundTripper struct {
14+
Responses []http.Response
15+
CallCount int
16+
}
17+
18+
func (m *mockRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
19+
// simulate responses for each call
20+
if m.CallCount < len(m.Responses) {
21+
resp := m.Responses[m.CallCount]
22+
m.CallCount++
23+
return &resp, nil
24+
}
25+
return &http.Response{StatusCode: http.StatusInternalServerError}, nil
26+
}
27+
28+
func TestCreateRetryableHTTPClient_RetryOn500(t *testing.T) {
29+
bodyString := "this is a test body"
30+
body := io.NopCloser(strings.NewReader(bodyString))
31+
mockResponses := []http.Response{
32+
{StatusCode: 500}, // should retry
33+
{StatusCode: 500}, // should retry
34+
{StatusCode: 500}, // should retry
35+
{StatusCode: 500}, // should retry
36+
{StatusCode: 500, Body: body}, // should retry
37+
}
38+
39+
mockTransport := &mockRoundTripper{Responses: mockResponses}
40+
41+
// base client with the mock transport
42+
baseClient := &http.Client{
43+
Transport: mockTransport,
44+
Timeout: 10 * time.Second,
45+
}
46+
47+
configuration, err := NewConfiguration(WithAuthorizationMode("Bearer"), WithHttpClientPtr(baseClient), WithRateLimitMaxBackOff(1), WithRateLimitMaxRetries(4))
48+
require.NoError(t, err, "Creating a new config should not error")
49+
client := NewAPIClient(configuration)
50+
51+
req, err := http.NewRequest("GET", "http://foo.com", nil)
52+
require.NoError(t, err)
53+
54+
resp, err := client.doWithRetries(client.cfg.Context, req)
55+
56+
require.Error(t, err)
57+
require.Equal(t, "unexpected status code 500, giving up after 4 retries", err.Error())
58+
require.True(t, resp != nil)
59+
require.Equal(t, resp.StatusCode, 500)
60+
message, err := io.ReadAll(resp.Body)
61+
require.NoError(t, err)
62+
require.Equal(t, bodyString, string(message))
63+
require.Equal(t, 5, mockTransport.CallCount)
64+
}

0 commit comments

Comments
 (0)