From db1efc8976eb58b5b1390a885d393a9c8f742011 Mon Sep 17 00:00:00 2001 From: Christian Wellenbrock Date: Mon, 29 Jan 2024 15:07:06 +0100 Subject: [PATCH] Stop heartbeat() in Connection.StopAllConsuming() Once all consumers of the connection have been stopped we also stop the heartbeat goroutine to avoid a goroutine leak. --- README.md | 3 +++ connection.go | 54 ++++++++++++++++++++++++++++++------------- errors.go | 2 +- queue_cluster_test.go | 32 +++++++++++++++++++++++++ 4 files changed, 74 insertions(+), 17 deletions(-) diff --git a/README.md b/README.md index e290932..df48650 100644 --- a/README.md +++ b/README.md @@ -412,6 +412,9 @@ use case where you actually need that sort of flexibility, please let us know. Currently for each queue you are only supposed to call `StartConsuming()` and `StopConsuming()` at most once. +Also note that `StopAllConsuming()` will stop the heartbeat for this connection. +It's advised to also not publish to any queue opened by this connection anymore. + ### Return Rejected Deliveries Even if you don't have a push queue setup there are cases where you need to diff --git a/connection.go b/connection.go index 055be0f..0247627 100644 --- a/connection.go +++ b/connection.go @@ -57,7 +57,7 @@ type redisConnection struct { redisClient RedisClient errChan chan<- error - heartbeatStop chan chan struct{} + heartbeatStop chan chan struct{} // used to stop heartbeat() in stopHeartbeat(), nil once stopped lock sync.Mutex stopped bool @@ -112,7 +112,7 @@ func openConnection(tag string, redisClient RedisClient, useRedisHashTags bool, rejectedTemplate: getTemplate(queueRejectedBaseTemplate, useRedisHashTags), redisClient: redisClient, errChan: errChan, - heartbeatStop: make(chan chan struct{}, 1), + heartbeatStop: make(chan chan struct{}, 1), // mark heartbeat as active, can be stopped } if err := connection.updateHeartbeat(); err != nil { // checks the connection @@ -144,9 +144,9 @@ func (connection *redisConnection) heartbeat(errChan chan<- error) { select { case <-ticker.C: // continue below - case c := <-connection.heartbeatStop: - close(c) - return + case c := <-connection.heartbeatStop: // stopHeartbeat() has been called + close(c) // confirm to stopHeartbeat() that the heartbeat is stopped + return // stop updating the heartbeat } err := connection.updateHeartbeat() @@ -160,7 +160,13 @@ func (connection *redisConnection) heartbeat(errChan chan<- error) { if errorCount >= HeartbeatErrorLimit { // reached error limit + + // To avoid using this connection while we're not able to maintain its heartbeat we stop all + // consumers. This in turn will call stopHeartbeat() and the responsibility of heartbeat() to + // confirm that the heartbeat is stopped, so we do that here too. connection.StopAllConsuming() + close(<-connection.heartbeatStop) // wait for stopHeartbeat() and confirm heartbeat is stopped + // Clients reading from errChan need to see this error // This allows them to shut themselves down // Therefore we block adding it to errChan to ensure delivery @@ -223,8 +229,15 @@ func (connection *redisConnection) StopAllConsuming() <-chan struct{} { finishedChan := make(chan struct{}) - // If we are already stopped or there are no open queues, then there is nothing to do - if connection.stopped || len(connection.openQueues) == 0 { + // If we are already stopped then there is nothing to do + if connection.stopped { + close(finishedChan) + return finishedChan + } + + // If there are no open queues we still want to stop the heartbeat + if len(connection.openQueues) == 0 { + connection.stopHeartbeat() close(finishedChan) return finishedChan } @@ -239,8 +252,11 @@ func (connection *redisConnection) StopAllConsuming() <-chan struct{} { for _, c := range chans { <-c } - close(finishedChan) - // log.Printf("rmq connection stopped consuming %s", queue) + + // All consuming has been stopped. Now we can stop the heartbeat to avoid a goroutine leak. + connection.stopHeartbeat() + + close(finishedChan) // signal all done }() return finishedChan @@ -317,23 +333,29 @@ func (connection *redisConnection) openQueue(name string) Queue { ) } -// stopHeartbeat stops the heartbeat of the connection -// it does not remove it from the list of connections so it can later be found by the cleaner +// stopHeartbeat stops the heartbeat of the connection. +// It does not remove it from the list of connections so it can later be found by the cleaner. +// Returns ErrorNotFound if the heartbeat was already stopped. +// Note that this function itself is not threadsafe, it's important to not call it multiple times +// at the same time. Currently it's only called in StopAllConsuming() where it's linearized by +// connection.lock. func (connection *redisConnection) stopHeartbeat() error { - if connection.heartbeatStop == nil { + if connection.heartbeatStop == nil { // already stopped return ErrorNotFound } heartbeatStopped := make(chan struct{}) connection.heartbeatStop <- heartbeatStopped - <-heartbeatStopped - connection.heartbeatStop = nil // avoid stopping twice + <-heartbeatStopped // wait for heartbeat() to confirm it's stopped + connection.heartbeatStop = nil // mark heartbeat as stopped + // Delete heartbeat key to immediately make the connection appear inactive to the cleaner, + // instead of waiting for the heartbeat key to run into its TTL. count, err := connection.redisClient.Del(connection.heartbeatKey) - if err != nil { + if err != nil { // redis error return err } - if count == 0 { + if count == 0 { // heartbeat key didn't exist return ErrorNotFound } return nil diff --git a/errors.go b/errors.go index d0c92f8..1382f94 100644 --- a/errors.go +++ b/errors.go @@ -6,7 +6,7 @@ import ( ) var ( - ErrorNotFound = errors.New("entity not found") // entitify being connection/queue/delivery + ErrorNotFound = errors.New("entity not found") // entity being connection/queue/delivery/heartbeat ErrorAlreadyConsuming = errors.New("must not call StartConsuming() multiple times") ErrorNotConsuming = errors.New("must call StartConsuming() before adding consumers") ErrorConsumingStopped = errors.New("consuming stopped") diff --git a/queue_cluster_test.go b/queue_cluster_test.go index 1ce3de2..7516531 100644 --- a/queue_cluster_test.go +++ b/queue_cluster_test.go @@ -646,6 +646,29 @@ func TestClusterStopConsuming_BatchConsumer(t *testing.T) { assert.NoError(t, connection.stopHeartbeat()) } +func TestClusterConnection_StopAllConsuming_CalledTwice(t *testing.T) { + redisOptions, closer := testClusterRedis(t) + defer closer() + + connection, err := OpenClusterConnection("conn1", redis.NewClusterClient(redisOptions), nil) + assert.NoError(t, err) + + finishedChan := connection.StopAllConsuming() + require.NotNil(t, finishedChan) + <-finishedChan // wait for stopping to finish + + // check that heartbeat has been stopped + assert.Equal(t, connection.checkHeartbeat(), ErrorNotFound) + + // it's safe to call StopAllConsuming again + finishedChan = connection.StopAllConsuming() + require.NotNil(t, finishedChan) + <-finishedChan // wait for stopping to finish + + // heartbeat is still stopped of course + assert.Equal(t, connection.checkHeartbeat(), ErrorNotFound) +} + func TestClusterConnection_StopAllConsuming_CantOpenQueue(t *testing.T) { redisOptions, closer := testClusterRedis(t) defer closer() @@ -657,6 +680,9 @@ func TestClusterConnection_StopAllConsuming_CantOpenQueue(t *testing.T) { require.NotNil(t, finishedChan) <-finishedChan // wait for stopping to finish + // check that heartbeat has been stopped + assert.Equal(t, connection.checkHeartbeat(), ErrorNotFound) + queue, err := connection.OpenQueue("consume-q") require.Nil(t, queue) require.Equal(t, ErrorConsumingStopped, err) @@ -677,6 +703,9 @@ func TestClusterConnection_StopAllConsuming_CantStartConsuming(t *testing.T) { require.NotNil(t, finishedChan) <-finishedChan // wait for stopping to finish + // check that heartbeat has been stopped + assert.Equal(t, connection.checkHeartbeat(), ErrorNotFound) + err = queue.StartConsuming(20, time.Millisecond) require.Equal(t, ErrorConsumingStopped, err) } @@ -717,6 +746,9 @@ func TestClusterConnection_StopAllConsuming_CantAddConsumer(t *testing.T) { require.NotNil(t, finishedChan) <-finishedChan // wait for stopping to finish + // check that heartbeat has been stopped + assert.Equal(t, connection.checkHeartbeat(), ErrorNotFound) + _, err = queue.AddConsumer("late-consume", NewTestConsumer("late-consumer")) require.Equal(t, ErrorConsumingStopped, err) }