Skip to content

Commit dd9cfa3

Browse files
committed
fix: improve error handling in the multipart request failure #1030
1 parent 09ed804 commit dd9cfa3

File tree

6 files changed

+158
-43
lines changed

6 files changed

+158
-43
lines changed

client.go

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2380,7 +2380,7 @@ func (c *Client) execute(req *Request) (*Response, error) {
23802380
}
23812381
if req.multipartErrChan != nil {
23822382
if err = <-req.multipartErrChan; err != nil {
2383-
return response, err
2383+
response.CascadeError = wrapErrors(err, response.CascadeError)
23842384
}
23852385
}
23862386
if resp != nil {
@@ -2390,18 +2390,18 @@ func (c *Client) execute(req *Request) (*Response, error) {
23902390

23912391
response.Body = resp.Body
23922392
if err = response.wrapContentDecompresser(); err != nil {
2393-
return response, err
2393+
return response, response.wrapError(err, false)
23942394
}
23952395

23962396
response.wrapLimitReadCloser()
2397-
}
23982397

2399-
if !req.DoNotParseResponse {
2400-
if req.ResponseBodyUnlimitedReads || req.Debug {
2401-
response.wrapCopyReadCloser()
2398+
if !req.DoNotParseResponse {
2399+
if req.ResponseBodyUnlimitedReads || req.Debug {
2400+
response.wrapCopyReadCloser()
24022401

2403-
if err = response.readAll(); err != nil {
2404-
return response, err
2402+
if err = response.readAll(); err != nil {
2403+
return response, response.wrapError(err, false)
2404+
}
24052405
}
24062406
}
24072407
}
@@ -2415,8 +2415,7 @@ func (c *Client) execute(req *Request) (*Response, error) {
24152415
}
24162416
}
24172417

2418-
err = response.CascadeError
2419-
return response, err
2418+
return response, response.wrapError(nil, false)
24202419
}
24212420

24222421
// getting TLS client config if not exists then create one

middleware.go

Lines changed: 41 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -314,16 +314,38 @@ func handleMultipart(c *Client, r *Request) error {
314314
}
315315

316316
r.Header.Set(hdrContentTypeKey, mw.FormDataContentType())
317-
closeq(mw)
318317

319-
return nil
318+
return mw.Close()
319+
}
320+
321+
// pre-process multipart fields to catch possible errors
322+
for _, mf := range r.multipartFields {
323+
if len(mf.Values) > 0 {
324+
continue
325+
}
326+
if err := mf.openFileIfRequired(); err != nil {
327+
return err
328+
}
329+
330+
// probe the file to catch possible errors
331+
// and also detect content type if empty
332+
p := make([]byte, 512)
333+
size, err := mf.Reader.Read(p)
334+
if err != nil && err != io.EOF {
335+
return err
336+
}
337+
mf.tempBuf = p[:size]
338+
339+
// auto detect content type if empty
340+
if isStringEmpty(mf.ContentType) {
341+
mf.ContentType = http.DetectContentType(mf.tempBuf)
342+
}
320343
}
321344

322345
// multipart streaming
323346
bodyReader, bodyWriter := io.Pipe()
324347
mw := multipart.NewWriter(bodyWriter)
325348
r.Body = bodyReader
326-
r.multipartErrChan = make(chan error, 1)
327349

328350
// set boundary if it is provided by the user
329351
if !isStringEmpty(r.multipartBoundary) {
@@ -332,13 +354,21 @@ func handleMultipart(c *Client, r *Request) error {
332354
}
333355
}
334356

357+
r.multipartErrChan = make(chan error, 1)
335358
go func() {
336-
defer close(r.multipartErrChan)
359+
defer func() {
360+
if err := mw.Close(); err != nil {
361+
r.multipartErrChan <- err
362+
}
363+
if err := bodyWriter.Close(); err != nil {
364+
r.multipartErrChan <- err
365+
}
366+
close(r.multipartErrChan)
367+
}()
368+
337369
if err := createMultipart(mw, r); err != nil {
338370
r.multipartErrChan <- err
339371
}
340-
closeq(mw)
341-
closeq(bodyWriter)
342372
}()
343373

344374
r.Header.Set(hdrContentTypeKey, mw.FormDataContentType())
@@ -362,27 +392,15 @@ func createMultipart(w *multipart.Writer, r *Request) error {
362392
continue
363393
}
364394

365-
if err := mf.openFileIfRequired(); err != nil {
366-
return err
367-
}
368-
369-
p := make([]byte, 512)
370-
size, err := mf.Reader.Read(p)
371-
if err != nil && err != io.EOF {
372-
return err
373-
}
374-
// auto detect content type if empty
375-
if isStringEmpty(mf.ContentType) {
376-
mf.ContentType = http.DetectContentType(p[:size])
377-
}
378-
379395
partWriter, err := mpCreatePart(w, mf.createHeader())
380396
if err != nil {
381397
return err
382398
}
383399

384400
partWriter = mf.wrapProgressCallbackIfPresent(partWriter)
385-
partWriter.Write(p[:size])
401+
if len(mf.tempBuf) > 0 {
402+
partWriter.Write(mf.tempBuf)
403+
}
386404

387405
if _, err = ioCopy(partWriter, mf.Reader); err != nil {
388406
return err
@@ -482,7 +500,8 @@ func handleRequestBody(c *Client, r *Request) error {
482500
// based on registered HTTP response `Content-Type` decoder, see [Client.AddContentTypeDecoder];
483501
// if [Request.SetResult], [Request.SetResultError], or [Client.SetResultError] is used
484502
func AutoParseResponseMiddleware(c *Client, res *Response) (err error) {
485-
if res.CascadeError != nil || res.Request.DoNotParseResponse {
503+
if (res.CascadeError != nil && (res.Request.isMultiPart && res.StatusCode() == 0)) ||
504+
res.Request.DoNotParseResponse {
486505
return // move on
487506
}
488507

multipart.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,10 @@ type MultipartField struct {
5858
//
5959
// It is primarily added for ordered multipart form-data field use cases
6060
Values []string
61+
62+
// tempBuf is used to preserve the byte(s) read from the file to detect the content type.
63+
// Or any possible read error early.
64+
tempBuf []byte
6165
}
6266

6367
// Clone method returns the deep copy of m except [io.Reader].

multipart_test.go

Lines changed: 90 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ import (
1919
"path/filepath"
2020
"strconv"
2121
"strings"
22+
"sync"
2223
"testing"
2324
"time"
2425
)
@@ -88,8 +89,8 @@ func TestMultipartUploadError(t *testing.T) {
8889
Post(ts.URL + "/upload")
8990

9091
assertNotNil(t, err)
91-
assertNotNil(t, resp)
92-
assertTrue(t, errors.Is(err, fs.ErrNotExist))
92+
assertNil(t, resp)
93+
assertEqual(t, true, errors.Is(err, fs.ErrNotExist))
9394
}
9495

9596
func TestMultipartUploadFiles(t *testing.T) {
@@ -572,8 +573,7 @@ func TestMultipartReaderErrors(t *testing.T) {
572573

573574
assertNotNil(t, err)
574575
assertEqual(t, errTestErrorReader, err)
575-
assertNotNil(t, resp)
576-
assertEqual(t, nil, resp.Body)
576+
assertNil(t, resp)
577577
})
578578

579579
t.Run("multipart files with errorReader", func(t *testing.T) {
@@ -583,8 +583,7 @@ func TestMultipartReaderErrors(t *testing.T) {
583583

584584
assertNotNil(t, err)
585585
assertEqual(t, errTestErrorReader, err)
586-
assertNotNil(t, resp)
587-
assertEqual(t, nil, resp.Body)
586+
assertNil(t, resp)
588587
})
589588

590589
t.Run("multipart with file not found", func(t *testing.T) {
@@ -593,9 +592,8 @@ func TestMultipartReaderErrors(t *testing.T) {
593592
Post("/upload")
594593

595594
assertNotNil(t, err)
596-
assertTrue(t, errors.Is(err, fs.ErrNotExist))
597-
assertNotNil(t, resp)
598-
assertEqual(t, nil, resp.Body)
595+
assertEqual(t, true, errors.Is(err, fs.ErrNotExist))
596+
assertNil(t, resp)
599597
})
600598
}
601599

@@ -668,6 +666,89 @@ func TestMultipartRequest_createMultipart(t *testing.T) {
668666
})
669667
}
670668

669+
func TestMultipartUploadFailAutoErrorParse(t *testing.T) {
670+
type ErrorResponse struct {
671+
Code int `json:"code"`
672+
Message string `json:"message"`
673+
}
674+
675+
ts := createTestServer(func(w http.ResponseWriter, r *http.Request) {
676+
w.Header().Set(hdrContentTypeKey, "application/json")
677+
w.WriteHeader(http.StatusForbidden)
678+
_, _ = w.Write([]byte(`{ "code": 403, "message": "forbidden error message" }`))
679+
})
680+
defer ts.Close()
681+
682+
c := dcnl()
683+
684+
t.Run("single request", func(t *testing.T) {
685+
res, err := c.R().
686+
SetFile("profile_img", filepath.Join(getTestDataPath(), "test-img.png")).
687+
SetResultError(&ErrorResponse{}).
688+
Post(ts.URL)
689+
690+
assertErrorIs(t, io.ErrClosedPipe, err)
691+
assertEqual(t, http.StatusForbidden, res.StatusCode())
692+
693+
er := res.ResultError().(*ErrorResponse)
694+
assertEqual(t, 403, er.Code)
695+
assertEqual(t, "forbidden error message", er.Message)
696+
})
697+
698+
t.Run("concurrent requests", func(t *testing.T) {
699+
concurrencyCount := 100
700+
wg := sync.WaitGroup{}
701+
for i := 0; i < concurrencyCount; i++ {
702+
wg.Add(1)
703+
go func() {
704+
defer wg.Done()
705+
res, _ := c.R().
706+
SetFile("profile_img", filepath.Join(getTestDataPath(), "test-img.png")).
707+
SetResultError(&ErrorResponse{}).
708+
Post(ts.URL)
709+
710+
er := res.ResultError().(*ErrorResponse)
711+
assertEqual(t, http.StatusForbidden, res.StatusCode())
712+
assertEqual(t, 403, er.Code)
713+
assertEqual(t, "forbidden error message", er.Message)
714+
}()
715+
}
716+
wg.Wait()
717+
})
718+
719+
}
720+
721+
func TestMultipartConcurrentRequests(t *testing.T) {
722+
ts := createFormPostServer(t)
723+
defer ts.Close()
724+
defer cleanupFiles(".testdata/upload")
725+
726+
c := dcnl()
727+
c.SetFormData(map[string]string{"zip_code": "00001", "city": "Los Angeles"})
728+
729+
concurrencyCount := 100
730+
wg := sync.WaitGroup{}
731+
for i := 0; i < concurrencyCount; i++ {
732+
wg.Add(1)
733+
go func() {
734+
defer wg.Done()
735+
res, err := c.R().
736+
SetFormData(map[string]string{
737+
"welcome1": "welcome value 1",
738+
"welcome2": "welcome value 2",
739+
"welcome3": "welcome value 3",
740+
}).
741+
SetFile("profile_img", filepath.Join(getTestDataPath(), "test-img.png")).
742+
Post(ts.URL + "/upload")
743+
744+
assertError(t, err)
745+
assertEqual(t, http.StatusOK, res.StatusCode())
746+
assertEqual(t, true, strings.Contains(res.String(), "test-img.png"))
747+
}()
748+
}
749+
wg.Wait()
750+
}
751+
671752
type returnValueTestWriter struct {
672753
}
673754

response.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -335,3 +335,13 @@ func (r *Response) wrapContentDecompresser() error {
335335

336336
return nil
337337
}
338+
339+
func (r *Response) wrapError(err error, preserve bool) error {
340+
r.CascadeError = wrapErrors(err, r.CascadeError)
341+
if preserve {
342+
return nil
343+
}
344+
e := r.CascadeError
345+
r.CascadeError = nil
346+
return e
347+
}

resty_test.go

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -355,6 +355,8 @@ func createPostServer(t *testing.T) *httptest.Server {
355355

356356
func createFormPostServer(t *testing.T) *httptest.Server {
357357
ts := createTestServer(func(w http.ResponseWriter, r *http.Request) {
358+
t.Logf("Content-Type: %v", r.Header.Get(hdrConnectionKey))
359+
358360
if r.Method == MethodPost {
359361
_ = r.ParseMultipartForm(10e6)
360362

@@ -406,9 +408,9 @@ func createFormPostServer(t *testing.T) *httptest.Server {
406408
defer func() {
407409
_ = f.Close()
408410
}()
409-
_, _ = io.Copy(f, infile)
411+
size, _ := io.Copy(f, infile)
410412

411-
_, _ = w.Write([]byte(fmt.Sprintf("File: %v, uploaded as: %v\n", hdr.Filename, fname)))
413+
_, _ = w.Write([]byte(fmt.Sprintf("File: %v, uploaded as: %v, size: %v\n", hdr.Filename, fname, size)))
412414
}
413415
}
414416

0 commit comments

Comments
 (0)