Skip to content

Commit 8e85104

Browse files
Improve pubsub (redis#1764)
* Improve pubsub Signed-off-by: monkey92t <[email protected]> * Extract code to channel struct and tweak API * Move chanSendTimeout to channel * Cleanup health check * Add WithChannelSendTimeout and tweak comments * clear notes Signed-off-by: monkey92t <[email protected]> Co-authored-by: Vladimir Mihailenco <[email protected]>
1 parent f83600d commit 8e85104

File tree

2 files changed

+141
-89
lines changed

2 files changed

+141
-89
lines changed

pubsub.go

Lines changed: 122 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ package redis
22

33
import (
44
"context"
5-
"errors"
65
"fmt"
76
"strings"
87
"sync"
@@ -13,13 +12,6 @@ import (
1312
"github.com/go-redis/redis/v8/internal/proto"
1413
)
1514

16-
const (
17-
pingTimeout = time.Second
18-
chanSendTimeout = time.Minute
19-
)
20-
21-
var errPingTimeout = errors.New("redis: ping timeout")
22-
2315
// PubSub implements Pub/Sub commands as described in
2416
// http://redis.io/topics/pubsub. Message receiving is NOT safe
2517
// for concurrent use by multiple goroutines.
@@ -43,9 +35,12 @@ type PubSub struct {
4335
cmd *Cmd
4436

4537
chOnce sync.Once
46-
msgCh chan *Message
47-
allCh chan interface{}
48-
ping chan struct{}
38+
msgCh *channel
39+
allCh *channel
40+
}
41+
42+
func (c *PubSub) init() {
43+
c.exit = make(chan struct{})
4944
}
5045

5146
func (c *PubSub) String() string {
@@ -54,10 +49,6 @@ func (c *PubSub) String() string {
5449
return fmt.Sprintf("PubSub(%s)", strings.Join(channels, ", "))
5550
}
5651

57-
func (c *PubSub) init() {
58-
c.exit = make(chan struct{})
59-
}
60-
6152
func (c *PubSub) connWithLock(ctx context.Context) (*pool.Conn, error) {
6253
c.mu.Lock()
6354
cn, err := c.conn(ctx, nil)
@@ -418,110 +409,160 @@ func (c *PubSub) ReceiveMessage(ctx context.Context) (*Message, error) {
418409
}
419410
}
420411

412+
func (c *PubSub) getContext() context.Context {
413+
if c.cmd != nil {
414+
return c.cmd.ctx
415+
}
416+
return context.Background()
417+
}
418+
419+
//------------------------------------------------------------------------------
420+
421421
// Channel returns a Go channel for concurrently receiving messages.
422422
// The channel is closed together with the PubSub. If the Go channel
423423
// is blocked full for 30 seconds the message is dropped.
424424
// Receive* APIs can not be used after channel is created.
425425
//
426426
// go-redis periodically sends ping messages to test connection health
427427
// and re-subscribes if ping can not not received for 30 seconds.
428-
func (c *PubSub) Channel() <-chan *Message {
429-
return c.ChannelSize(100)
430-
}
431-
432-
// ChannelSize is like Channel, but creates a Go channel
433-
// with specified buffer size.
434-
func (c *PubSub) ChannelSize(size int) <-chan *Message {
428+
func (c *PubSub) Channel(opts ...ChannelOption) <-chan *Message {
435429
c.chOnce.Do(func() {
436-
c.initPing()
437-
c.initMsgChan(size)
430+
c.msgCh = newChannel(c, opts...)
431+
c.msgCh.initMsgChan()
438432
})
439433
if c.msgCh == nil {
440434
err := fmt.Errorf("redis: Channel can't be called after ChannelWithSubscriptions")
441435
panic(err)
442436
}
443-
if cap(c.msgCh) != size {
444-
err := fmt.Errorf("redis: PubSub.Channel size can not be changed once created")
445-
panic(err)
446-
}
447-
return c.msgCh
437+
return c.msgCh.msgCh
438+
}
439+
440+
// ChannelSize is like Channel, but creates a Go channel
441+
// with specified buffer size.
442+
//
443+
// Deprecated: use Channel(WithChannelSize(size)), remove in v9.
444+
func (c *PubSub) ChannelSize(size int) <-chan *Message {
445+
return c.Channel(WithChannelSize(size))
448446
}
449447

450448
// ChannelWithSubscriptions is like Channel, but message type can be either
451449
// *Subscription or *Message. Subscription messages can be used to detect
452450
// reconnections.
453451
//
454452
// ChannelWithSubscriptions can not be used together with Channel or ChannelSize.
455-
func (c *PubSub) ChannelWithSubscriptions(ctx context.Context, size int) <-chan interface{} {
453+
func (c *PubSub) ChannelWithSubscriptions(_ context.Context, size int) <-chan interface{} {
456454
c.chOnce.Do(func() {
457-
c.initPing()
458-
c.initAllChan(size)
455+
c.allCh = newChannel(c, WithChannelSize(size))
456+
c.allCh.initAllChan()
459457
})
460458
if c.allCh == nil {
461459
err := fmt.Errorf("redis: ChannelWithSubscriptions can't be called after Channel")
462460
panic(err)
463461
}
464-
if cap(c.allCh) != size {
465-
err := fmt.Errorf("redis: PubSub.Channel size can not be changed once created")
466-
panic(err)
462+
return c.allCh.allCh
463+
}
464+
465+
type ChannelOption func(c *channel)
466+
467+
// WithChannelSize specifies the Go chan size that is used to buffer incoming messages.
468+
//
469+
// The default is 100 messages.
470+
func WithChannelSize(size int) ChannelOption {
471+
return func(c *channel) {
472+
c.chanSize = size
467473
}
468-
return c.allCh
469474
}
470475

471-
func (c *PubSub) getContext() context.Context {
472-
if c.cmd != nil {
473-
return c.cmd.ctx
476+
// WithChannelHealthCheckInterval specifies the health check interval.
477+
// PubSub will ping Redis Server if it does not receive any messages within the interval.
478+
// To disable health check, use zero interval.
479+
//
480+
// The default is 3 seconds.
481+
func WithChannelHealthCheckInterval(d time.Duration) ChannelOption {
482+
return func(c *channel) {
483+
c.checkInterval = d
474484
}
475-
return context.Background()
476485
}
477486

478-
func (c *PubSub) initPing() {
487+
// WithChannelSendTimeout specifies that channel send timeout after which
488+
// the message is dropped.
489+
//
490+
// The default is 60 seconds.
491+
func WithChannelSendTimeout(d time.Duration) ChannelOption {
492+
return func(c *channel) {
493+
c.chanSendTimeout = d
494+
}
495+
}
496+
497+
type channel struct {
498+
pubSub *PubSub
499+
500+
msgCh chan *Message
501+
allCh chan interface{}
502+
ping chan struct{}
503+
504+
chanSize int
505+
chanSendTimeout time.Duration
506+
checkInterval time.Duration
507+
}
508+
509+
func newChannel(pubSub *PubSub, opts ...ChannelOption) *channel {
510+
c := &channel{
511+
pubSub: pubSub,
512+
513+
chanSize: 100,
514+
chanSendTimeout: time.Minute,
515+
checkInterval: 3 * time.Second,
516+
}
517+
for _, opt := range opts {
518+
opt(c)
519+
}
520+
if c.checkInterval > 0 {
521+
c.initHealthCheck()
522+
}
523+
return c
524+
}
525+
526+
func (c *channel) initHealthCheck() {
479527
ctx := context.TODO()
480528
c.ping = make(chan struct{}, 1)
529+
481530
go func() {
482531
timer := time.NewTimer(time.Minute)
483532
timer.Stop()
484533

485-
healthy := true
486534
for {
487-
timer.Reset(pingTimeout)
535+
timer.Reset(c.checkInterval)
488536
select {
489537
case <-c.ping:
490-
healthy = true
491538
if !timer.Stop() {
492539
<-timer.C
493540
}
494541
case <-timer.C:
495-
pingErr := c.Ping(ctx)
496-
if healthy {
497-
healthy = false
498-
} else {
499-
if pingErr == nil {
500-
pingErr = errPingTimeout
501-
}
502-
c.mu.Lock()
503-
c.reconnect(ctx, pingErr)
504-
healthy = true
505-
c.mu.Unlock()
542+
if pingErr := c.pubSub.Ping(ctx); pingErr != nil {
543+
c.pubSub.mu.Lock()
544+
c.pubSub.reconnect(ctx, pingErr)
545+
c.pubSub.mu.Unlock()
506546
}
507-
case <-c.exit:
547+
case <-c.pubSub.exit:
508548
return
509549
}
510550
}
511551
}()
512552
}
513553

514554
// initMsgChan must be in sync with initAllChan.
515-
func (c *PubSub) initMsgChan(size int) {
555+
func (c *channel) initMsgChan() {
516556
ctx := context.TODO()
517-
c.msgCh = make(chan *Message, size)
557+
c.msgCh = make(chan *Message, c.chanSize)
558+
518559
go func() {
519560
timer := time.NewTimer(time.Minute)
520561
timer.Stop()
521562

522563
var errCount int
523564
for {
524-
msg, err := c.Receive(ctx)
565+
msg, err := c.pubSub.Receive(ctx)
525566
if err != nil {
526567
if err == pool.ErrClosed {
527568
close(c.msgCh)
@@ -548,38 +589,36 @@ func (c *PubSub) initMsgChan(size int) {
548589
case *Pong:
549590
// Ignore.
550591
case *Message:
551-
timer.Reset(chanSendTimeout)
592+
timer.Reset(c.chanSendTimeout)
552593
select {
553594
case c.msgCh <- msg:
554595
if !timer.Stop() {
555596
<-timer.C
556597
}
557598
case <-timer.C:
558599
internal.Logger.Printf(
559-
c.getContext(),
560-
"redis: %s channel is full for %s (message is dropped)",
561-
c,
562-
chanSendTimeout,
563-
)
600+
ctx, "redis: %s channel is full for %s (message is dropped)",
601+
c, c.chanSendTimeout)
564602
}
565603
default:
566-
internal.Logger.Printf(c.getContext(), "redis: unknown message type: %T", msg)
604+
internal.Logger.Printf(ctx, "redis: unknown message type: %T", msg)
567605
}
568606
}
569607
}()
570608
}
571609

572610
// initAllChan must be in sync with initMsgChan.
573-
func (c *PubSub) initAllChan(size int) {
611+
func (c *channel) initAllChan() {
574612
ctx := context.TODO()
575-
c.allCh = make(chan interface{}, size)
613+
c.allCh = make(chan interface{}, c.chanSize)
614+
576615
go func() {
577-
timer := time.NewTimer(pingTimeout)
616+
timer := time.NewTimer(time.Minute)
578617
timer.Stop()
579618

580619
var errCount int
581620
for {
582-
msg, err := c.Receive(ctx)
621+
msg, err := c.pubSub.Receive(ctx)
583622
if err != nil {
584623
if err == pool.ErrClosed {
585624
close(c.allCh)
@@ -601,29 +640,23 @@ func (c *PubSub) initAllChan(size int) {
601640
}
602641

603642
switch msg := msg.(type) {
604-
case *Subscription:
605-
c.sendMessage(msg, timer)
606643
case *Pong:
607644
// Ignore.
608-
case *Message:
609-
c.sendMessage(msg, timer)
645+
case *Subscription, *Message:
646+
timer.Reset(c.chanSendTimeout)
647+
select {
648+
case c.allCh <- msg:
649+
if !timer.Stop() {
650+
<-timer.C
651+
}
652+
case <-timer.C:
653+
internal.Logger.Printf(
654+
ctx, "redis: %s channel is full for %s (message is dropped)",
655+
c, c.chanSendTimeout)
656+
}
610657
default:
611-
internal.Logger.Printf(c.getContext(), "redis: unknown message type: %T", msg)
658+
internal.Logger.Printf(ctx, "redis: unknown message type: %T", msg)
612659
}
613660
}
614661
}()
615662
}
616-
617-
func (c *PubSub) sendMessage(msg interface{}, timer *time.Timer) {
618-
timer.Reset(pingTimeout)
619-
select {
620-
case c.allCh <- msg:
621-
if !timer.Stop() {
622-
<-timer.C
623-
}
624-
case <-timer.C:
625-
internal.Logger.Printf(
626-
c.getContext(),
627-
"redis: %s channel is full for %s (message is dropped)", c, pingTimeout)
628-
}
629-
}

pubsub_test.go

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -473,4 +473,23 @@ var _ = Describe("PubSub", func() {
473473
Fail("timeout")
474474
}
475475
})
476+
477+
It("should ChannelMessage", func() {
478+
pubsub := client.Subscribe(ctx, "mychannel")
479+
defer pubsub.Close()
480+
481+
ch := pubsub.Channel(
482+
redis.WithChannelSize(10),
483+
redis.WithChannelHealthCheckInterval(time.Second),
484+
)
485+
486+
text := "test channel message"
487+
err := client.Publish(ctx, "mychannel", text).Err()
488+
Expect(err).NotTo(HaveOccurred())
489+
490+
var msg *redis.Message
491+
Eventually(ch).Should(Receive(&msg))
492+
Expect(msg.Channel).To(Equal("mychannel"))
493+
Expect(msg.Payload).To(Equal(text))
494+
})
476495
})

0 commit comments

Comments
 (0)