Skip to content

Commit d6e2578

Browse files
authored
fix: improve flow and error handling in the multipart request #1030 (#1101)
1 parent 09ed804 commit d6e2578

File tree

9 files changed

+398
-145
lines changed

9 files changed

+398
-145
lines changed

client.go

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -710,6 +710,7 @@ func (c *Client) R() *Request {
710710
AllowNonIdempotentRetry: c.allowNonIdempotentRetry,
711711
HeaderAuthorizationKey: c.headerAuthorizationKey,
712712

713+
mu: new(sync.Mutex),
713714
client: c,
714715
baseURL: c.baseURL,
715716
multipartFields: make([]*MultipartField, 0),
@@ -2372,36 +2373,42 @@ func (c *Client) execute(req *Request) (*Response, error) {
23722373

23732374
req.Time = time.Now()
23742375
resp, err := c.Client().Do(req.withTimeout())
2376+
// Cancel multipart context for io.Copy to stop reading/writing further
2377+
if req.isMultiPart && req.multipartCancelFunc != nil {
2378+
req.multipartCancelFunc()
2379+
}
23752380

23762381
response := &Response{Request: req, RawResponse: resp}
23772382
response.setReceivedAt()
23782383
if err != nil {
23792384
return response, err
23802385
}
2381-
if req.multipartErrChan != nil {
2382-
if err = <-req.multipartErrChan; err != nil {
2383-
return response, err
2386+
if req.isMultiPart && req.multipartErrChan != nil {
2387+
// read all multipart errors from channel
2388+
for err = range req.multipartErrChan {
2389+
response.CascadeError = wrapErrors(err, response.CascadeError)
23842390
}
23852391
}
2392+
23862393
if resp != nil {
23872394
if c.circuitBreaker != nil {
23882395
c.circuitBreaker.applyPolicies(resp)
23892396
}
23902397

23912398
response.Body = resp.Body
23922399
if err = response.wrapContentDecompresser(); err != nil {
2393-
return response, err
2400+
return response, response.wrapError(err, false)
23942401
}
23952402

23962403
response.wrapLimitReadCloser()
2397-
}
23982404

2399-
if !req.DoNotParseResponse {
2400-
if req.ResponseBodyUnlimitedReads || req.Debug {
2401-
response.wrapCopyReadCloser()
2405+
if !req.DoNotParseResponse {
2406+
if req.ResponseBodyUnlimitedReads || req.Debug {
2407+
response.wrapCopyReadCloser()
24022408

2403-
if err = response.readAll(); err != nil {
2404-
return response, err
2409+
if err = response.readAll(); err != nil {
2410+
return response, response.wrapError(err, false)
2411+
}
24052412
}
24062413
}
24072414
}
@@ -2415,8 +2422,7 @@ func (c *Client) execute(req *Request) (*Response, error) {
24152422
}
24162423
}
24172424

2418-
err = response.CascadeError
2419-
return response, err
2425+
return response, response.wrapError(nil, false)
24202426
}
24212427

24222428
// getting TLS client config if not exists then create one

middleware.go

Lines changed: 104 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ package resty
77

88
import (
99
"bytes"
10+
"context"
1011
"fmt"
1112
"io"
1213
"mime"
@@ -238,7 +239,7 @@ func createRawRequest(c *Client, r *Request) (err error) {
238239
}
239240

240241
// get the context reference back from underlying RawRequest
241-
r.ctx = r.RawRequest.Context()
242+
r.SetContext(r.RawRequest.Context())
242243

243244
// Assign close connection option
244245
r.RawRequest.Close = r.CloseConnection
@@ -289,105 +290,138 @@ func addCredentials(c *Client, r *Request) error {
289290
return nil
290291
}
291292

292-
func handleMultipart(c *Client, r *Request) error {
293-
for k, v := range c.FormData() {
294-
if _, ok := r.FormData[k]; ok {
295-
continue
296-
}
297-
r.FormData[k] = v[:]
298-
}
299-
300-
mfLen := len(r.multipartFields)
301-
if mfLen == 0 {
302-
r.bodyBuf = acquireBuffer()
303-
mw := multipart.NewWriter(r.bodyBuf)
293+
var multipartWriteField = func(w *multipart.Writer, name, value string) error {
294+
return w.WriteField(name, value)
295+
}
304296

305-
// set boundary if it is provided by the user
306-
if !isStringEmpty(r.multipartBoundary) {
307-
if err := mw.SetBoundary(r.multipartBoundary); err != nil {
297+
var multipartWriteFormData = func(w *multipart.Writer, r *Request) error {
298+
for k, v := range r.FormData {
299+
for _, iv := range v {
300+
if err := multipartWriteField(w, k, iv); err != nil {
308301
return err
309302
}
310303
}
304+
}
305+
return nil
306+
}
311307

312-
if err := r.writeFormData(mw); err != nil {
313-
return err
314-
}
315-
316-
r.Header.Set(hdrContentTypeKey, mw.FormDataContentType())
317-
closeq(mw)
308+
var multipartCreatePart = func(w *multipart.Writer, h textproto.MIMEHeader) (io.Writer, error) {
309+
return w.CreatePart(h)
310+
}
318311

312+
var multipartSetBoundary = func(w *multipart.Writer, r *Request) error {
313+
if isStringEmpty(r.multipartBoundary) {
319314
return nil
320315
}
316+
return w.SetBoundary(r.multipartBoundary)
317+
}
321318

322-
// multipart streaming
323-
bodyReader, bodyWriter := io.Pipe()
324-
mw := multipart.NewWriter(bodyWriter)
325-
r.Body = bodyReader
326-
r.multipartErrChan = make(chan error, 1)
319+
func handleMultipartFormData(r *Request) error {
320+
r.bodyBuf = acquireBuffer()
321+
mw := multipart.NewWriter(r.bodyBuf)
322+
defer mw.Close()
327323

328-
// set boundary if it is provided by the user
329-
if !isStringEmpty(r.multipartBoundary) {
330-
if err := mw.SetBoundary(r.multipartBoundary); err != nil {
331-
return err
332-
}
324+
// set custom multipart boundary if exists
325+
if err := multipartSetBoundary(mw, r); err != nil {
326+
return err
333327
}
334328

335-
go func() {
336-
defer close(r.multipartErrChan)
337-
if err := createMultipart(mw, r); err != nil {
338-
r.multipartErrChan <- err
339-
}
340-
closeq(mw)
341-
closeq(bodyWriter)
342-
}()
343-
344329
r.Header.Set(hdrContentTypeKey, mw.FormDataContentType())
345-
return nil
346-
}
347330

348-
var mpCreatePart = func(w *multipart.Writer, h textproto.MIMEHeader) (io.Writer, error) {
349-
return w.CreatePart(h)
331+
return multipartWriteFormData(mw, r)
350332
}
351333

352-
func createMultipart(w *multipart.Writer, r *Request) error {
353-
if err := r.writeFormData(w); err != nil {
354-
return err
334+
func handleMultipart(c *Client, r *Request) error {
335+
for k, v := range c.FormData() {
336+
if _, ok := r.FormData[k]; ok {
337+
continue
338+
}
339+
r.FormData[k] = v[:]
355340
}
356341

342+
if len(r.multipartFields) == 0 {
343+
return handleMultipartFormData(r)
344+
}
345+
346+
// pre-process multipart fields to catch possible errors
357347
for _, mf := range r.multipartFields {
358-
if len(mf.Values) > 0 {
359-
for _, v := range mf.Values {
360-
w.WriteField(mf.Name, v)
361-
}
348+
if mf.isValues() {
362349
continue
363350
}
364351

365-
if err := mf.openFileIfRequired(); err != nil {
352+
if err := mf.openFile(); err != nil {
366353
return err
367354
}
368355

369-
p := make([]byte, 512)
370-
size, err := mf.Reader.Read(p)
371-
if err != nil && err != io.EOF {
356+
if err := mf.detectContentType(); err != nil {
372357
return err
373358
}
374-
// auto detect content type if empty
375-
if isStringEmpty(mf.ContentType) {
376-
mf.ContentType = http.DetectContentType(p[:size])
377-
}
359+
}
378360

379-
partWriter, err := mpCreatePart(w, mf.createHeader())
380-
if err != nil {
381-
return err
361+
// multipart streaming
362+
br, bw := io.Pipe()
363+
mw := multipart.NewWriter(bw)
364+
r.Body = br
365+
366+
// set custom multipart boundary if exists
367+
if err := multipartSetBoundary(mw, r); err != nil {
368+
closeq(bw)
369+
return err
370+
}
371+
372+
r.Header.Set(hdrContentTypeKey, mw.FormDataContentType())
373+
374+
r.multipartErrChan = make(chan error, 1)
375+
go func() {
376+
defer close(r.multipartErrChan)
377+
defer func() {
378+
if err := mw.Close(); err != nil {
379+
r.multipartErrChan <- err
380+
}
381+
if err := bw.Close(); err != nil {
382+
r.multipartErrChan <- err
383+
}
384+
}()
385+
386+
if err := multipartWriteFormData(mw, r); err != nil {
387+
r.multipartErrChan <- err
388+
return
382389
}
383390

384-
partWriter = mf.wrapProgressCallbackIfPresent(partWriter)
385-
partWriter.Write(p[:size])
391+
ctx, cancel := context.WithCancel(r.Context())
392+
r.multipartCancelFunc = cancel
393+
for _, mf := range r.multipartFields {
394+
if mf.isValues() {
395+
for _, v := range mf.Values {
396+
if err := multipartWriteField(mw, mf.Name, v); err != nil {
397+
r.multipartErrChan <- err
398+
return
399+
}
400+
}
401+
continue
402+
}
386403

387-
if _, err = ioCopy(partWriter, mf.Reader); err != nil {
388-
return err
404+
partWriter, err := multipartCreatePart(mw, mf.createHeader())
405+
if err != nil {
406+
r.multipartErrChan <- err
407+
return
408+
}
409+
410+
partWriter = mf.wrapProgressCallbackIfPresent(partWriter)
411+
if len(mf.tempBuf) > 0 {
412+
if _, err = partWriter.Write(mf.tempBuf); err != nil {
413+
r.multipartErrChan <- err
414+
return
415+
}
416+
}
417+
418+
reader := &gracefulStopReader{ctx: ctx, r: mf.Reader}
419+
if _, err = ioCopy(partWriter, reader); err != nil {
420+
r.multipartErrChan <- err
421+
return
422+
}
389423
}
390-
}
424+
}()
391425

392426
return nil
393427
}
@@ -482,7 +516,8 @@ func handleRequestBody(c *Client, r *Request) error {
482516
// based on registered HTTP response `Content-Type` decoder, see [Client.AddContentTypeDecoder];
483517
// if [Request.SetResult], [Request.SetResultError], or [Client.SetResultError] is used
484518
func AutoParseResponseMiddleware(c *Client, res *Response) (err error) {
485-
if res.CascadeError != nil || res.Request.DoNotParseResponse {
519+
if (res.CascadeError != nil && (res.Request.isMultiPart && res.StatusCode() == 0)) ||
520+
res.Request.DoNotParseResponse {
486521
return // move on
487522
}
488523

0 commit comments

Comments
 (0)