Skip to content

Commit 262cd58

Browse files
committed
close connections in case of write errors (#613)
1 parent a2df9d8 commit 262cd58

12 files changed

+421
-344
lines changed

async_processor.go

Lines changed: 24 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,46 +4,57 @@ import (
44
"github.com/bluenviron/gortsplib/v4/pkg/ringbuffer"
55
)
66

7-
// this struct contains a queue that allows to detach the routine that is reading a stream
7+
// this is an asynchronous queue processor
8+
// that allows to detach the routine that is reading a stream
89
// from the routine that is writing a stream.
910
type asyncProcessor struct {
11+
bufferSize int
12+
1013
running bool
1114
buffer *ringbuffer.RingBuffer
1215

13-
done chan struct{}
16+
chError chan error
1417
}
1518

16-
func (w *asyncProcessor) allocateBuffer(size int) {
17-
w.buffer, _ = ringbuffer.New(uint64(size))
19+
func (w *asyncProcessor) initialize() {
20+
w.buffer, _ = ringbuffer.New(uint64(w.bufferSize))
1821
}
1922

2023
func (w *asyncProcessor) start() {
2124
w.running = true
22-
w.done = make(chan struct{})
25+
w.chError = make(chan error)
2326
go w.run()
2427
}
2528

2629
func (w *asyncProcessor) stop() {
27-
if w.running {
28-
w.buffer.Close()
29-
<-w.done
30-
w.running = false
30+
if !w.running {
31+
panic("should not happen")
3132
}
33+
w.buffer.Close()
34+
<-w.chError
35+
w.running = false
3236
}
3337

3438
func (w *asyncProcessor) run() {
35-
defer close(w.done)
39+
err := w.runInner()
40+
w.chError <- err
41+
close(w.chError)
42+
}
3643

44+
func (w *asyncProcessor) runInner() error {
3745
for {
3846
tmp, ok := w.buffer.Pull()
3947
if !ok {
40-
return
48+
return nil
4149
}
4250

43-
tmp.(func())()
51+
err := tmp.(func() error)()
52+
if err != nil {
53+
return err
54+
}
4455
}
4556
}
4657

47-
func (w *asyncProcessor) push(cb func()) bool {
58+
func (w *asyncProcessor) push(cb func() error) bool {
4859
return w.buffer.Push(cb)
4960
}

client.go

Lines changed: 70 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -335,22 +335,19 @@ type Client struct {
335335
keepalivePeriod time.Duration
336336
keepaliveTimer *time.Timer
337337
closeError error
338-
writer asyncProcessor
338+
writer *asyncProcessor
339339
reader *clientReader
340340
timeDecoder *rtptime.GlobalDecoder2
341341
mustClose bool
342342

343343
// in
344-
chOptions chan optionsReq
345-
chDescribe chan describeReq
346-
chAnnounce chan announceReq
347-
chSetup chan setupReq
348-
chPlay chan playReq
349-
chRecord chan recordReq
350-
chPause chan pauseReq
351-
chReadError chan error
352-
chReadResponse chan *base.Response
353-
chReadRequest chan *base.Request
344+
chOptions chan optionsReq
345+
chDescribe chan describeReq
346+
chAnnounce chan announceReq
347+
chSetup chan setupReq
348+
chPlay chan playReq
349+
chRecord chan recordReq
350+
chPause chan pauseReq
354351

355352
// out
356353
done chan struct{}
@@ -462,9 +459,6 @@ func (c *Client) Start(scheme string, host string) error {
462459
c.chPlay = make(chan playReq)
463460
c.chRecord = make(chan recordReq)
464461
c.chPause = make(chan pauseReq)
465-
c.chReadError = make(chan error)
466-
c.chReadResponse = make(chan *base.Response)
467-
c.chReadRequest = make(chan *base.Request)
468462
c.done = make(chan struct{})
469463

470464
go c.run()
@@ -530,6 +524,34 @@ func (c *Client) run() {
530524

531525
func (c *Client) runInner() error {
532526
for {
527+
chReaderResponse := func() chan *base.Response {
528+
if c.reader != nil {
529+
return c.reader.chResponse
530+
}
531+
return nil
532+
}()
533+
534+
chReaderRequest := func() chan *base.Request {
535+
if c.reader != nil {
536+
return c.reader.chRequest
537+
}
538+
return nil
539+
}()
540+
541+
chReaderError := func() chan error {
542+
if c.reader != nil {
543+
return c.reader.chError
544+
}
545+
return nil
546+
}()
547+
548+
chWriterError := func() chan error {
549+
if c.writer != nil {
550+
return c.writer.chError
551+
}
552+
return nil
553+
}()
554+
533555
select {
534556
case req := <-c.chOptions:
535557
res, err := c.doOptions(req.url)
@@ -601,15 +623,18 @@ func (c *Client) runInner() error {
601623
}
602624
c.keepaliveTimer = time.NewTimer(c.keepalivePeriod)
603625

604-
case err := <-c.chReadError:
626+
case err := <-chWriterError:
627+
return err
628+
629+
case err := <-chReaderError:
605630
c.reader = nil
606631
return err
607632

608-
case res := <-c.chReadResponse:
633+
case res := <-chReaderResponse:
609634
c.OnResponse(res)
610635
// these are responses to keepalives, ignore them.
611636

612-
case req := <-c.chReadRequest:
637+
case req := <-chReaderRequest:
613638
err := c.handleServerRequest(req)
614639
if err != nil {
615640
return err
@@ -630,19 +655,19 @@ func (c *Client) waitResponse(requestCseqStr string) (*base.Response, error) {
630655
case <-t.C:
631656
return nil, liberrors.ErrClientRequestTimedOut{}
632657

633-
case err := <-c.chReadError:
658+
case err := <-c.reader.chError:
634659
c.reader = nil
635660
return nil, err
636661

637-
case res := <-c.chReadResponse:
662+
case res := <-c.reader.chResponse:
638663
c.OnResponse(res)
639664

640665
// accept response if CSeq equals request CSeq, or if CSeq is not present
641666
if cseq, ok := res.Header["CSeq"]; !ok || len(cseq) != 1 || strings.TrimSpace(cseq[0]) == requestCseqStr {
642667
return res, nil
643668
}
644669

645-
case req := <-c.chReadRequest:
670+
case req := <-c.reader.chRequest:
646671
err := c.handleServerRequest(req)
647672
if err != nil {
648673
return nil, err
@@ -682,8 +707,8 @@ func (c *Client) handleServerRequest(req *base.Request) error {
682707

683708
func (c *Client) doClose() {
684709
if c.state == clientStatePlay || c.state == clientStateRecord {
685-
c.stopWriter()
686-
c.stopReadRoutines()
710+
c.writer.stop()
711+
c.stopTransportRoutines()
687712
}
688713

689714
if c.nconn != nil && c.baseURL != nil {
@@ -808,15 +833,21 @@ func (c *Client) trySwitchingProtocol2(medi *description.Media, baseURL *base.UR
808833
return c.doSetup(baseURL, medi, 0, 0)
809834
}
810835

811-
func (c *Client) startReadRoutines() {
836+
func (c *Client) startTransportRoutines() {
812837
// allocate writer here because it's needed by RTCP receiver / sender
813838
if c.state == clientStateRecord || c.backChannelSetupped {
814-
c.writer.allocateBuffer(c.WriteQueueSize)
839+
c.writer = &asyncProcessor{
840+
bufferSize: c.WriteQueueSize,
841+
}
842+
c.writer.initialize()
815843
} else {
816844
// when reading, buffer is only used to send RTCP receiver reports,
817845
// that are much smaller than RTP packets and are sent at a fixed interval.
818846
// decrease RAM consumption by allocating less buffers.
819-
c.writer.allocateBuffer(8)
847+
c.writer = &asyncProcessor{
848+
bufferSize: 8,
849+
}
850+
c.writer.initialize()
820851
}
821852

822853
c.timeDecoder = rtptime.NewGlobalDecoder2()
@@ -848,7 +879,7 @@ func (c *Client) startReadRoutines() {
848879
}
849880
}
850881

851-
func (c *Client) stopReadRoutines() {
882+
func (c *Client) stopTransportRoutines() {
852883
if c.reader != nil {
853884
c.reader.setAllowInterleavedFrames(false)
854885
}
@@ -861,14 +892,8 @@ func (c *Client) stopReadRoutines() {
861892
}
862893

863894
c.timeDecoder = nil
864-
}
865-
866-
func (c *Client) startWriter() {
867-
c.writer.start()
868-
}
869895

870-
func (c *Client) stopWriter() {
871-
c.writer.stop()
896+
c.writer = nil
872897
}
873898

874899
func (c *Client) connOpen() error {
@@ -1637,7 +1662,7 @@ func (c *Client) doPlay(ra *headers.Range) (*base.Response, error) {
16371662
}
16381663

16391664
c.state = clientStatePlay
1640-
c.startReadRoutines()
1665+
c.startTransportRoutines()
16411666

16421667
// Range is mandatory in Parrot Streaming Server
16431668
if ra == nil {
@@ -1662,13 +1687,13 @@ func (c *Client) doPlay(ra *headers.Range) (*base.Response, error) {
16621687
Header: header,
16631688
}, false)
16641689
if err != nil {
1665-
c.stopReadRoutines()
1690+
c.stopTransportRoutines()
16661691
c.state = clientStatePrePlay
16671692
return nil, err
16681693
}
16691694

16701695
if res.StatusCode != base.StatusOK {
1671-
c.stopReadRoutines()
1696+
c.stopTransportRoutines()
16721697
c.state = clientStatePrePlay
16731698
return nil, liberrors.ErrClientBadStatusCode{
16741699
Code: res.StatusCode, Message: res.StatusMessage,
@@ -1689,7 +1714,7 @@ func (c *Client) doPlay(ra *headers.Range) (*base.Response, error) {
16891714
}
16901715
}
16911716

1692-
c.startWriter()
1717+
c.writer.start()
16931718
c.lastRange = ra
16941719

16951720
return res, nil
@@ -1718,27 +1743,27 @@ func (c *Client) doRecord() (*base.Response, error) {
17181743
}
17191744

17201745
c.state = clientStateRecord
1721-
c.startReadRoutines()
1746+
c.startTransportRoutines()
17221747

17231748
res, err := c.do(&base.Request{
17241749
Method: base.Record,
17251750
URL: c.baseURL,
17261751
}, false)
17271752
if err != nil {
1728-
c.stopReadRoutines()
1753+
c.stopTransportRoutines()
17291754
c.state = clientStatePreRecord
17301755
return nil, err
17311756
}
17321757

17331758
if res.StatusCode != base.StatusOK {
1734-
c.stopReadRoutines()
1759+
c.stopTransportRoutines()
17351760
c.state = clientStatePreRecord
17361761
return nil, liberrors.ErrClientBadStatusCode{
17371762
Code: res.StatusCode, Message: res.StatusMessage,
17381763
}
17391764
}
17401765

1741-
c.startWriter()
1766+
c.writer.start()
17421767

17431768
return nil, nil
17441769
}
@@ -1766,25 +1791,25 @@ func (c *Client) doPause() (*base.Response, error) {
17661791
return nil, err
17671792
}
17681793

1769-
c.stopWriter()
1794+
c.writer.stop()
17701795

17711796
res, err := c.do(&base.Request{
17721797
Method: base.Pause,
17731798
URL: c.baseURL,
17741799
}, false)
17751800
if err != nil {
1776-
c.startWriter()
1801+
c.writer.start()
17771802
return nil, err
17781803
}
17791804

17801805
if res.StatusCode != base.StatusOK {
1781-
c.startWriter()
1806+
c.writer.start()
17821807
return nil, liberrors.ErrClientBadStatusCode{
17831808
Code: res.StatusCode, Message: res.StatusMessage,
17841809
}
17851810
}
17861811

1787-
c.stopReadRoutines()
1812+
c.stopTransportRoutines()
17881813

17891814
switch c.state {
17901815
case clientStatePlay:
@@ -1929,15 +1954,3 @@ func (c *Client) PacketNTP(medi *description.Media, pkt *rtp.Packet) (time.Time,
19291954
ct := cm.formats[pkt.PayloadType]
19301955
return ct.rtcpReceiver.PacketNTP(pkt.Timestamp)
19311956
}
1932-
1933-
func (c *Client) readResponse(res *base.Response) {
1934-
c.chReadResponse <- res
1935-
}
1936-
1937-
func (c *Client) readRequest(req *base.Request) {
1938-
c.chReadRequest <- req
1939-
}
1940-
1941-
func (c *Client) readError(err error) {
1942-
c.chReadError <- err
1943-
}

client_format.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,8 +74,8 @@ func (cf *clientFormat) stop() {
7474
func (cf *clientFormat) writePacketRTP(byts []byte, pkt *rtp.Packet, ntp time.Time) error {
7575
cf.rtcpSender.ProcessPacket(pkt, ntp, cf.format.PTSEqualsDTS(pkt))
7676

77-
ok := cf.cm.c.writer.push(func() {
78-
cf.cm.writePacketRTPInQueue(byts)
77+
ok := cf.cm.c.writer.push(func() error {
78+
return cf.cm.writePacketRTPInQueue(byts)
7979
})
8080
if !ok {
8181
return liberrors.ErrClientWriteQueueFull{}

0 commit comments

Comments
 (0)