Skip to content

Commit 0fadf80

Browse files
authored
fix: sse request body issue on retry (#1126)
1 parent 8c57f2c commit 0fadf80

2 files changed

Lines changed: 108 additions & 30 deletions

File tree

sse.go

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ type (
7171
url string
7272
method string
7373
header http.Header
74-
body io.Reader
74+
bodyBytes []byte
7575
lastEventID string
7676
retryCount int
7777
retryWaitTime time.Duration
@@ -161,7 +161,21 @@ func (sse *SSESource) SetHeader(header, value string) *SSESource {
161161
// Example:
162162
// sse.SetBody(bytes.NewReader([]byte(`{"test":"put_data"}`)))
163163
func (sse *SSESource) SetBody(body io.Reader) *SSESource {
164-
sse.body = body
164+
sse.lock.Lock()
165+
defer sse.lock.Unlock()
166+
if body == nil {
167+
sse.bodyBytes = nil
168+
return sse
169+
}
170+
171+
sse.bodyBytes = nil
172+
bodyBytes, err := ioReadAll(body)
173+
if err != nil {
174+
sse.log.Errorf("resty:sse: unable to read body, error: %v", err)
175+
return sse
176+
}
177+
178+
sse.bodyBytes = bodyBytes
165179
return sse
166180
}
167181

@@ -530,7 +544,13 @@ func (sse *SSESource) triggerOnRequestFailure(err error, res *http.Response) {
530544
}
531545

532546
func (sse *SSESource) createRequest() (*http.Request, error) {
533-
req, err := http.NewRequest(sse.method, sse.url, sse.body)
547+
var reqBody io.Reader
548+
if sse.bodyBytes != nil {
549+
// create reader from bytes on each request
550+
reqBody = bytes.NewReader(sse.bodyBytes)
551+
}
552+
553+
req, err := http.NewRequest(sse.method, sse.url, reqBody)
534554
if err != nil {
535555
return nil, err
536556
}
@@ -571,6 +591,7 @@ func (sse *SSESource) connect() (*http.Response, error) {
571591

572592
resp, doErr := sse.httpClient.Do(req)
573593
if resp != nil && resp.StatusCode == http.StatusOK {
594+
// successful connection, return response to listenStream
574595
return resp, nil
575596
}
576597

sse_test.go

Lines changed: 84 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@ import (
2020
"time"
2121
)
2222

23-
func TestEventSourceSimpleFlow(t *testing.T) {
24-
es := createEventSource(t, "", nil, nil)
23+
func TestSSESourceSimpleFlow(t *testing.T) {
24+
es := createSSESource(t, "", nil, nil)
2525

2626
messageCounter := 0
2727
messageFunc := func(e any) {
@@ -57,7 +57,7 @@ func TestEventSourceSimpleFlow(t *testing.T) {
5757
assertEqual(t, counter, messageCounter)
5858
}
5959

60-
func TestEventSourceMultipleEventTypes(t *testing.T) {
60+
func TestSSESourceMultipleEventTypes(t *testing.T) {
6161
type userEvent struct {
6262
UserName string `json:"username"`
6363
Message string `json:"msg"`
@@ -83,7 +83,7 @@ func TestEventSourceMultipleEventTypes(t *testing.T) {
8383
}
8484

8585
counter := 0
86-
es := createEventSource(t, "", func(any) {}, nil)
86+
es := createSSESource(t, "", func(any) {}, nil)
8787
ts := createSSETestServer(
8888
t,
8989
10*time.Millisecond,
@@ -130,11 +130,11 @@ func TestEventSourceMultipleEventTypes(t *testing.T) {
130130
assertEqual(t, userConnectCounter, userMessageCounter)
131131
}
132132

133-
func TestEventSourceOverwriteFuncs(t *testing.T) {
133+
func TestSSESourceOverwriteFuncs(t *testing.T) {
134134
messageFunc1 := func(e any) {
135135
assertNotNil(t, e)
136136
}
137-
es := createEventSource(t, "", messageFunc1, nil)
137+
es := createSSESource(t, "", messageFunc1, nil)
138138

139139
message2Counter := 0
140140
messageFunc2 := func(e any) {
@@ -184,8 +184,8 @@ func TestEventSourceOverwriteFuncs(t *testing.T) {
184184
assertTrue(t, strings.Contains(logLines, "Overwriting an existing OnError callback"))
185185
}
186186

187-
func TestEventSourceRetry(t *testing.T) {
188-
es := createEventSource(t, "", nil, nil)
187+
func TestSSESourceRetry(t *testing.T) {
188+
es := createSSESource(t, "", nil, nil)
189189

190190
messageCounter := 2 // 0 & 1 connection failure
191191
messageFunc := func(e any) {
@@ -265,10 +265,51 @@ func TestEventSourceRetry(t *testing.T) {
265265
assertNotNil(t, err2)
266266
}
267267

268-
func TestEventSourceTLSConfigerInterface(t *testing.T) {
268+
func TestSSESourceRetryReusesRequestBody(t *testing.T) {
269+
const payload = `{"test":"retry-body"}`
270+
271+
attempt := 0
272+
bodies := make([]string, 0, 2)
273+
ts := createTestServer(func(w http.ResponseWriter, r *http.Request) {
274+
body, err := io.ReadAll(r.Body)
275+
assertNil(t, err)
276+
bodies = append(bodies, string(body))
277+
278+
attempt++
279+
if attempt == 1 {
280+
w.WriteHeader(http.StatusTooManyRequests)
281+
return
282+
}
283+
284+
w.WriteHeader(http.StatusOK)
285+
})
286+
defer ts.Close()
287+
288+
es := NewSSESource().
289+
SetURL(ts.URL).
290+
SetMethod(MethodPost).
291+
SetRetryCount(1).
292+
SetRetryWaitTime(1 * time.Millisecond).
293+
SetRetryMaxWaitTime(1 * time.Millisecond)
294+
es.SetBody(bytes.NewBufferString(payload))
295+
296+
resp, err := es.connect()
297+
assertNil(t, err)
298+
assertNotNil(t, resp)
299+
if resp != nil {
300+
closeq(resp.Body)
301+
}
302+
303+
assertEqual(t, 2, attempt, "expected one retry attempt")
304+
assertEqual(t, 2, len(bodies), "expected request body on both attempts")
305+
assertEqual(t, payload, bodies[0], "expected first attempt body to match")
306+
assertEqual(t, payload, bodies[1], "expected retry attempt body to match")
307+
}
308+
309+
func TestSSESourceTLSConfigerInterface(t *testing.T) {
269310

270311
t.Run("set and get tls config", func(t *testing.T) {
271-
es := createEventSource(t, "", func(any) {}, nil)
312+
es := createSSESource(t, "", func(any) {}, nil)
272313

273314
tc, err := es.tlsConfig()
274315
assertNil(t, err)
@@ -280,15 +321,15 @@ func TestEventSourceTLSConfigerInterface(t *testing.T) {
280321
})
281322

282323
t.Run("get tls config error", func(t *testing.T) {
283-
es := createEventSource(t, "", func(any) {}, nil)
324+
es := createSSESource(t, "", func(any) {}, nil)
284325

285326
ct := &CustomRoundTripper1{}
286327
es.httpClient.Transport = ct
287328
assertNil(t, es.TLSClientConfig())
288329
})
289330

290331
t.Run("set tls config", func(t *testing.T) {
291-
es := createEventSource(t, "", func(any) {}, nil)
332+
es := createSSESource(t, "", func(any) {}, nil)
292333

293334
ct := &CustomRoundTripper2{}
294335
es.httpClient.Transport = ct
@@ -299,7 +340,7 @@ func TestEventSourceTLSConfigerInterface(t *testing.T) {
299340
})
300341

301342
t.Run("set tls config error", func(t *testing.T) {
302-
es := createEventSource(t, "", func(any) {}, nil)
343+
es := createSSESource(t, "", func(any) {}, nil)
303344

304345
ct := &CustomRoundTripper2{returnErr: true}
305346
es.httpClient.Transport = ct
@@ -310,8 +351,8 @@ func TestEventSourceTLSConfigerInterface(t *testing.T) {
310351
})
311352
}
312353

313-
func TestEventSourceNoRetryRequired(t *testing.T) {
314-
es := createEventSource(t, "", func(any) {}, nil)
354+
func TestSSESourceNoRetryRequired(t *testing.T) {
355+
es := createSSESource(t, "", func(any) {}, nil)
315356
ts := createTestServer(func(w http.ResponseWriter, r *http.Request) {
316357
w.WriteHeader(http.StatusBadRequest)
317358
})
@@ -343,7 +384,7 @@ func TestGH1044TrimHeader(t *testing.T) {
343384
}
344385

345386
func TestGH1041RequestFailureWithResponseBody(t *testing.T) {
346-
es := createEventSource(t, "", func(any) {}, nil)
387+
es := createSSESource(t, "", func(any) {}, nil)
347388
ts := createTestServer(func(w http.ResponseWriter, r *http.Request) {
348389
w.Header().Set(hdrContentTypeKey, jsonContentType)
349390
w.WriteHeader(http.StatusBadRequest)
@@ -367,8 +408,8 @@ func TestGH1041RequestFailureWithResponseBody(t *testing.T) {
367408
assertEqual(t, "resty:sse: 400 Bad Request", err.Error())
368409
}
369410

370-
func TestEventSourceHTTPError(t *testing.T) {
371-
es := createEventSource(t, "", func(any) {}, nil)
411+
func TestSSESourceHTTPError(t *testing.T) {
412+
es := createSSESource(t, "", func(any) {}, nil)
372413
ts := createTestServer(func(w http.ResponseWriter, r *http.Request) {
373414
http.Redirect(w, r, "http://local host", http.StatusTemporaryRedirect)
374415
})
@@ -379,10 +420,10 @@ func TestEventSourceHTTPError(t *testing.T) {
379420
assertTrue(t, strings.Contains(err.Error(), `invalid character " " in host name`))
380421
}
381422

382-
func TestEventSourceParseAndReadError(t *testing.T) {
423+
func TestSSESourceParseAndReadError(t *testing.T) {
383424
type data struct{}
384425
counter := 0
385-
es := createEventSource(t, "", func(any) {}, data{})
426+
es := createSSESource(t, "", func(any) {}, data{})
386427
ts := createSSETestServer(
387428
t,
388429
5*time.Millisecond,
@@ -415,8 +456,8 @@ func TestEventSourceParseAndReadError(t *testing.T) {
415456
})
416457
}
417458

418-
func TestEventSourceReadError(t *testing.T) {
419-
es := createEventSource(t, "", func(any) {}, nil)
459+
func TestSSESourceReadError(t *testing.T) {
460+
es := createSSESource(t, "", func(any) {}, nil)
420461
ts := createTestServer(func(w http.ResponseWriter, r *http.Request) {
421462
w.WriteHeader(http.StatusOK)
422463
})
@@ -436,7 +477,7 @@ func TestEventSourceReadError(t *testing.T) {
436477
assertTrue(t, strings.Contains(err.Error(), "read event test error"))
437478
}
438479

439-
func TestEventSourceWithDifferentMethods(t *testing.T) {
480+
func TestSSESourceWithDifferentMethods(t *testing.T) {
440481
testCases := []struct {
441482
name string
442483
method string
@@ -471,7 +512,7 @@ func TestEventSourceWithDifferentMethods(t *testing.T) {
471512

472513
for _, tc := range testCases {
473514
t.Run(tc.name, func(t *testing.T) {
474-
es := createEventSource(t, "", nil, nil)
515+
es := createSSESource(t, "", nil, nil)
475516

476517
messageCounter := 0
477518
messageFunc := func(e any) {
@@ -530,7 +571,7 @@ func TestEventSourceWithDifferentMethods(t *testing.T) {
530571
}
531572
}
532573

533-
func TestEventSource_readEventFunc(t *testing.T) {
574+
func TestSSESource_readEventFunc(t *testing.T) {
534575
t.Run("successful scan", func(t *testing.T) {
535576
input := "event: test\ndata: test data\n\n"
536577
scanner := bufio.NewScanner(strings.NewReader(input))
@@ -589,7 +630,7 @@ func TestEventSource_readEventFunc(t *testing.T) {
589630
})
590631
}
591632

592-
func TestEventSourceCoverage(t *testing.T) {
633+
func TestSSESourceCoverage(t *testing.T) {
593634
es := NewSSESource()
594635
err1 := es.Get()
595636
assertEqual(t, "resty:sse: event source URL is required", err1.Error())
@@ -608,7 +649,23 @@ func TestEventSourceCoverage(t *testing.T) {
608649
parseEvent([]byte{})
609650
}
610651

611-
func createEventSource(t *testing.T, url string, fn SSEMessageFunc, rt any) *SSESource {
652+
func TestSSESetBody(t *testing.T) {
653+
t.Run("nil input", func(t *testing.T) {
654+
es := createSSESource(t, "", nil, nil)
655+
656+
es.SetBody(nil)
657+
assertNil(t, es.bodyBytes)
658+
})
659+
660+
t.Run("read error", func(t *testing.T) {
661+
es := createSSESource(t, "", nil, nil)
662+
663+
es.SetBody(&errorReader{})
664+
assertNil(t, es.bodyBytes)
665+
})
666+
}
667+
668+
func createSSESource(t *testing.T, url string, fn SSEMessageFunc, rt any) *SSESource {
612669
es := NewSSESource().
613670
SetURL(url).
614671
SetMethod(MethodGet).

0 commit comments

Comments
 (0)