@@ -7,6 +7,7 @@ package resty
77
88import (
99 "bytes"
10+ "context"
1011 "fmt"
1112 "io"
1213 "mime"
@@ -238,7 +239,7 @@ func createRawRequest(c *Client, r *Request) (err error) {
238239 }
239240
240241 // get the context reference back from underlying RawRequest
241- r .ctx = r .RawRequest .Context ()
242+ r .SetContext ( r .RawRequest .Context () )
242243
243244 // Assign close connection option
244245 r .RawRequest .Close = r .CloseConnection
@@ -289,105 +290,138 @@ func addCredentials(c *Client, r *Request) error {
289290 return nil
290291}
291292
292- func handleMultipart (c * Client , r * Request ) error {
293- for k , v := range c .FormData () {
294- if _ , ok := r .FormData [k ]; ok {
295- continue
296- }
297- r .FormData [k ] = v [:]
298- }
299-
300- mfLen := len (r .multipartFields )
301- if mfLen == 0 {
302- r .bodyBuf = acquireBuffer ()
303- mw := multipart .NewWriter (r .bodyBuf )
293+ var multipartWriteField = func (w * multipart.Writer , name , value string ) error {
294+ return w .WriteField (name , value )
295+ }
304296
305- // set boundary if it is provided by the user
306- if ! isStringEmpty (r .multipartBoundary ) {
307- if err := mw .SetBoundary (r .multipartBoundary ); err != nil {
297+ var multipartWriteFormData = func (w * multipart.Writer , r * Request ) error {
298+ for k , v := range r .FormData {
299+ for _ , iv := range v {
300+ if err := multipartWriteField (w , k , iv ); err != nil {
308301 return err
309302 }
310303 }
304+ }
305+ return nil
306+ }
311307
312- if err := r .writeFormData (mw ); err != nil {
313- return err
314- }
315-
316- r .Header .Set (hdrContentTypeKey , mw .FormDataContentType ())
317- closeq (mw )
308+ var multipartCreatePart = func (w * multipart.Writer , h textproto.MIMEHeader ) (io.Writer , error ) {
309+ return w .CreatePart (h )
310+ }
318311
312+ var multipartSetBoundary = func (w * multipart.Writer , r * Request ) error {
313+ if isStringEmpty (r .multipartBoundary ) {
319314 return nil
320315 }
316+ return w .SetBoundary (r .multipartBoundary )
317+ }
321318
322- // multipart streaming
323- bodyReader , bodyWriter := io .Pipe ()
324- mw := multipart .NewWriter (bodyWriter )
325- r .Body = bodyReader
326- r .multipartErrChan = make (chan error , 1 )
319+ func handleMultipartFormData (r * Request ) error {
320+ r .bodyBuf = acquireBuffer ()
321+ mw := multipart .NewWriter (r .bodyBuf )
322+ defer mw .Close ()
327323
328- // set boundary if it is provided by the user
329- if ! isStringEmpty (r .multipartBoundary ) {
330- if err := mw .SetBoundary (r .multipartBoundary ); err != nil {
331- return err
332- }
324+ // set custom multipart boundary if exists
325+ if err := multipartSetBoundary (mw , r ); err != nil {
326+ return err
333327 }
334328
335- go func () {
336- defer close (r .multipartErrChan )
337- if err := createMultipart (mw , r ); err != nil {
338- r .multipartErrChan <- err
339- }
340- closeq (mw )
341- closeq (bodyWriter )
342- }()
343-
344329 r .Header .Set (hdrContentTypeKey , mw .FormDataContentType ())
345- return nil
346- }
347330
348- var mpCreatePart = func (w * multipart.Writer , h textproto.MIMEHeader ) (io.Writer , error ) {
349- return w .CreatePart (h )
331+ return multipartWriteFormData (mw , r )
350332}
351333
352- func createMultipart (w * multipart.Writer , r * Request ) error {
353- if err := r .writeFormData (w ); err != nil {
354- return err
334+ func handleMultipart (c * Client , r * Request ) error {
335+ for k , v := range c .FormData () {
336+ if _ , ok := r .FormData [k ]; ok {
337+ continue
338+ }
339+ r .FormData [k ] = v [:]
355340 }
356341
342+ if len (r .multipartFields ) == 0 {
343+ return handleMultipartFormData (r )
344+ }
345+
346+ // pre-process multipart fields to catch possible errors
357347 for _ , mf := range r .multipartFields {
358- if len (mf .Values ) > 0 {
359- for _ , v := range mf .Values {
360- w .WriteField (mf .Name , v )
361- }
348+ if mf .isValues () {
362349 continue
363350 }
364351
365- if err := mf .openFileIfRequired (); err != nil {
352+ if err := mf .openFile (); err != nil {
366353 return err
367354 }
368355
369- p := make ([]byte , 512 )
370- size , err := mf .Reader .Read (p )
371- if err != nil && err != io .EOF {
356+ if err := mf .detectContentType (); err != nil {
372357 return err
373358 }
374- // auto detect content type if empty
375- if isStringEmpty (mf .ContentType ) {
376- mf .ContentType = http .DetectContentType (p [:size ])
377- }
359+ }
378360
379- partWriter , err := mpCreatePart (w , mf .createHeader ())
380- if err != nil {
381- return err
361+ // multipart streaming
362+ br , bw := io .Pipe ()
363+ mw := multipart .NewWriter (bw )
364+ r .Body = br
365+
366+ // set custom multipart boundary if exists
367+ if err := multipartSetBoundary (mw , r ); err != nil {
368+ closeq (bw )
369+ return err
370+ }
371+
372+ r .Header .Set (hdrContentTypeKey , mw .FormDataContentType ())
373+
374+ r .multipartErrChan = make (chan error , 1 )
375+ go func () {
376+ defer close (r .multipartErrChan )
377+ defer func () {
378+ if err := mw .Close (); err != nil {
379+ r .multipartErrChan <- err
380+ }
381+ if err := bw .Close (); err != nil {
382+ r .multipartErrChan <- err
383+ }
384+ }()
385+
386+ if err := multipartWriteFormData (mw , r ); err != nil {
387+ r .multipartErrChan <- err
388+ return
382389 }
383390
384- partWriter = mf .wrapProgressCallbackIfPresent (partWriter )
385- partWriter .Write (p [:size ])
391+ ctx , cancel := context .WithCancel (r .Context ())
392+ r .multipartCancelFunc = cancel
393+ for _ , mf := range r .multipartFields {
394+ if mf .isValues () {
395+ for _ , v := range mf .Values {
396+ if err := multipartWriteField (mw , mf .Name , v ); err != nil {
397+ r .multipartErrChan <- err
398+ return
399+ }
400+ }
401+ continue
402+ }
386403
387- if _ , err = ioCopy (partWriter , mf .Reader ); err != nil {
388- return err
404+ partWriter , err := multipartCreatePart (mw , mf .createHeader ())
405+ if err != nil {
406+ r .multipartErrChan <- err
407+ return
408+ }
409+
410+ partWriter = mf .wrapProgressCallbackIfPresent (partWriter )
411+ if len (mf .tempBuf ) > 0 {
412+ if _ , err = partWriter .Write (mf .tempBuf ); err != nil {
413+ r .multipartErrChan <- err
414+ return
415+ }
416+ }
417+
418+ reader := & gracefulStopReader {ctx : ctx , r : mf .Reader }
419+ if _ , err = ioCopy (partWriter , reader ); err != nil {
420+ r .multipartErrChan <- err
421+ return
422+ }
389423 }
390- }
424+ }()
391425
392426 return nil
393427}
@@ -482,7 +516,8 @@ func handleRequestBody(c *Client, r *Request) error {
482516// based on registered HTTP response `Content-Type` decoder, see [Client.AddContentTypeDecoder];
483517// if [Request.SetResult], [Request.SetResultError], or [Client.SetResultError] is used
484518func AutoParseResponseMiddleware (c * Client , res * Response ) (err error ) {
485- if res .CascadeError != nil || res .Request .DoNotParseResponse {
519+ if (res .CascadeError != nil && (res .Request .isMultiPart && res .StatusCode () == 0 )) ||
520+ res .Request .DoNotParseResponse {
486521 return // move on
487522 }
488523
0 commit comments