Commit 543ed3a
fix(pool): fix race condition with small pool sizes (#83)
* fix(dispatcher): some tasks are missed
* double test runs
* revert initial linkedbuffer size
* remove brittle assertion
* Revert dispatcher chan size
* comment buffer reset
* add lock while dispatching tasks to avoid deadlocks
1 parent f1d2a44 commit 543ed3a
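
The heart of the change is in pool.go. Before this commit, an idle worker would exit as soon as it found the task channel momentarily empty, so with a small pool the last worker could disappear while the dispatcher was still submitting a batch, and a task could end up sitting in the queue (or stuck in a blocking send) with nobody left to run it. The fix introduces a dispatcherRunning mutex that dispatch holds for the whole batch, plus a workerCanExit check that only lets a worker exit once it can grab that mutex and the queue is empty. Below is a minimal, self-contained sketch of that exit protocol; it is not the library's code: miniPool, maxWorkers, workers and the func() task type are invented for illustration, while tasks, dispatcherRunning and startWorker echo names from the diff.

// Minimal sketch of the exit protocol (illustrative, not the library code).
package main

import (
	"fmt"
	"sync"
	"sync/atomic"
)

type miniPool struct {
	tasks             chan func()
	dispatcherRunning sync.Mutex
	workerCount       atomic.Int64
	workers           sync.WaitGroup
	maxWorkers        int64
}

func (p *miniPool) dispatch(incoming []func()) {
	// Hold the mutex for the whole batch: no idle worker can slip out between
	// two submissions and leave a queued task (or a blocked send) behind.
	p.dispatcherRunning.Lock()
	defer p.dispatcherRunning.Unlock()

	for _, task := range incoming {
		select {
		case p.tasks <- task:
			// Fast path: queue had room; make sure a worker exists to drain it.
			p.startWorker()
		default:
			p.startWorker() // queue full: ensure someone is draining it
			p.tasks <- task // then block until the task is accepted
		}
	}
}

func (p *miniPool) startWorker() {
	if p.workerCount.Load() >= p.maxWorkers {
		return
	}
	p.workerCount.Add(1)
	p.workers.Add(1)
	go p.worker()
}

func (p *miniPool) worker() {
	defer p.workers.Done()
	for {
		select {
		case task := <-p.tasks:
			task()
		default:
			// Queue looks empty: exit only if the dispatcher is not mid-dispatch
			// and the queue is still empty while we hold the lock.
			if p.dispatcherRunning.TryLock() {
				canExit := len(p.tasks) == 0
				if canExit {
					p.workerCount.Add(-1)
				}
				p.dispatcherRunning.Unlock()
				if canExit {
					return
				}
			}
		}
	}
}

func main() {
	p := &miniPool{tasks: make(chan func(), 1), maxWorkers: 1}
	p.dispatch([]func(){
		func() { fmt.Println("task 1") },
		func() { fmt.Println("task 2") },
		func() { fmt.Println("task 3") },
	})
	// Without the lock, the worker could exit mid-dispatch and strand a task.
	p.workers.Wait()
}

The actual patch adds one refinement the sketch omits: startWorker numbers each worker, and only workers beyond PERSISTENT_WORKER_COUNT (the number of CPUs) may exit while the dispatcher holds the lock. With NumCPU = 4, a fifth or sixth worker can still shrink away as soon as it is idle, while the first four always go through the TryLock check and therefore stay alive whenever a dispatch is in progress.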

File tree: 7 files changed (+99, -31 lines)

.github/workflows/main.yml

Lines changed: 1 addition & 1 deletion
@@ -23,7 +23,7 @@ jobs:
           go-version: ${{ matrix.go-version }}
 
       - name: Test
-        run: make test
+        run: make test-ci
   codecov:
     name: Coverage report
     runs-on: ubuntu-latest

Makefile

Lines changed: 4 additions & 1 deletion
@@ -1,5 +1,8 @@
 test:
-	go test -race -v -timeout 1m ./...
+	go test -race -v -timeout 15s -count=1 ./...
+
+test-ci:
+	go test -race -v -timeout 1m -count=3 ./...
 
 coverage:
 	go test -race -v -timeout 1m -coverprofile=coverage.out -covermode=atomic ./...

group_test.go

Lines changed: 3 additions & 2 deletions
@@ -36,8 +36,9 @@ func TestResultTaskGroupWait(t *testing.T) {
 
 func TestResultTaskGroupWaitWithError(t *testing.T) {
 
-	group := NewResultPool[int](1).
-		NewGroup()
+	pool := NewResultPool[int](1)
+
+	group := pool.NewGroup()
 
 	sampleErr := errors.New("sample error")
 

internal/dispatcher/dispatcher.go

Lines changed: 0 additions & 4 deletions
@@ -30,7 +30,6 @@ func NewDispatcher[T any](ctx context.Context, dispatchFunc func([]T), batchSize
 		bufferHasElements: make(chan struct{}, 1),
 		dispatchFunc:      dispatchFunc,
 		batchSize:         batchSize,
-		closed:            atomic.Bool{},
 	}
 
 	dispatcher.waitGroup.Add(1)
@@ -118,9 +117,6 @@ func (d *Dispatcher[T]) run(ctx context.Context) {
 
 			// Submit the next batch of values
 			d.dispatchFunc(batch[0:batchSize])
-
-			// Reset batch
-			batch = batch[:0]
 		}
 
 		if !ok || d.closed.Load() {

internal/linkedbuffer/linkedbuffer.go

Lines changed: 3 additions & 6 deletions
@@ -17,7 +17,7 @@ type LinkedBuffer[T any] struct {
 	maxCapacity int
 	writeCount  atomic.Uint64
 	readCount   atomic.Uint64
-	mutex       sync.RWMutex
+	mutex       sync.Mutex
 }
 
 func NewLinkedBuffer[T any](initialCapacity, maxCapacity int) *LinkedBuffer[T] {
@@ -78,28 +78,25 @@ func (b *LinkedBuffer[T]) Write(values []T) {
 
 // Read reads values from the buffer and returns the number of elements read
 func (b *LinkedBuffer[T]) Read(values []T) int {
+	b.mutex.Lock()
+	defer b.mutex.Unlock()
 
 	var readBuffer *Buffer[T]
 
 	for {
-		b.mutex.RLock()
 		readBuffer = b.readBuffer
-		b.mutex.RUnlock()
 
 		// Read element
 		n, err := readBuffer.Read(values)
 
 		if err == ErrEOF {
 			// Move to next buffer
-			b.mutex.Lock()
 			if readBuffer.next == nil {
-				b.mutex.Unlock()
 				return n
 			}
 			if b.readBuffer != readBuffer.next {
 				b.readBuffer = readBuffer.next
 			}
-			b.mutex.Unlock()
 			continue
 		}
 
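
The linkedbuffer change follows the same theme. Read is not a read-only operation: it advances the read position and can swap b.readBuffer to the next buffer in the chain, yet the old code only held an RLock long enough to copy the buffer pointer and did the actual reading unlocked. The patch takes a plain sync.Mutex once for the whole call. As a generic illustration of why an operation that both inspects and mutates shared state needs exclusive locking for its full duration, here is a small sketch; the cursor type and its fields are invented, and this is not the library's code:

// Illustrative sketch only: a shared read cursor needs one exclusive lock for
// the whole pop, not a brief read lock around copying a pointer.
package main

import (
	"fmt"
	"sync"
)

type cursor struct {
	mu   sync.Mutex // the pre-patch code used sync.RWMutex and only briefly RLock-ed
	data []int
	next int // shared read position, advanced by every Read
}

// Read pops one element. Because it inspects and advances shared state, the
// lock is held for the entire operation.
func (c *cursor) Read() (int, bool) {
	c.mu.Lock()
	defer c.mu.Unlock()
	if c.next >= len(c.data) {
		return 0, false
	}
	v := c.data[c.next]
	c.next++
	return v, true
}

func main() {
	c := &cursor{data: []int{1, 2, 3, 4}}
	var wg sync.WaitGroup
	for i := 0; i < 2; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			for {
				v, ok := c.Read()
				if !ok {
					return
				}
				// Each value is printed exactly once; with the brief-RLock
				// pattern, duplicates or skipped values become possible.
				fmt.Println(v)
			}
		}()
	}
	wg.Wait()
}

The new TestLinkedBufferWithReusedBuffer (next file) exercises the same property on the real type, cycling writes and reads through a buffer created with a maxCapacity of 1 so that, as the test name suggests, internal buffers end up being reused.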

internal/linkedbuffer/linkedbuffer_test.go

Lines changed: 39 additions & 0 deletions
@@ -91,3 +91,42 @@ func TestLinkedBufferLen(t *testing.T) {
 	buf.readCount.Add(1)
 	assert.Equal(t, uint64(0), buf.Len())
 }
+
+func TestLinkedBufferWithReusedBuffer(t *testing.T) {
+
+	buf := NewLinkedBuffer[int](2, 1)
+
+	values := make([]int, 1)
+
+	buf.Write([]int{1})
+	buf.Write([]int{2})
+
+	n := buf.Read(values)
+
+	assert.Equal(t, 1, n)
+	assert.Equal(t, 1, values[0])
+
+	assert.Equal(t, 1, len(values))
+	assert.Equal(t, 1, cap(values))
+
+	n = buf.Read(values)
+
+	assert.Equal(t, 1, n)
+	assert.Equal(t, 1, len(values))
+	assert.Equal(t, 2, values[0])
+
+	buf.Write([]int{3})
+	buf.Write([]int{4})
+
+	n = buf.Read(values)
+
+	assert.Equal(t, 1, n)
+	assert.Equal(t, 1, len(values))
+	assert.Equal(t, 3, values[0])
+
+	n = buf.Read(values)
+
+	assert.Equal(t, 1, n)
+	assert.Equal(t, 1, len(values))
+	assert.Equal(t, 4, values[0])
+}

pool.go

Lines changed: 49 additions & 17 deletions
@@ -14,6 +14,8 @@ import (
 
 var MAX_TASKS_CHAN_LENGTH = runtime.NumCPU() * 128
 
+var PERSISTENT_WORKER_COUNT = int64(runtime.NumCPU())
+
 var ErrPoolStopped = errors.New("pool stopped")
 
 var poolStoppedFuture = func() Task {
@@ -91,6 +93,7 @@ type pool struct {
 	workerCount         atomic.Int64
 	workerWaitGroup     sync.WaitGroup
 	dispatcher          *dispatcher.Dispatcher[any]
+	dispatcherRunning   sync.Mutex
 	successfulTaskCount atomic.Uint64
 	failedTaskCount     atomic.Uint64
 }
@@ -196,15 +199,16 @@ func (p *pool) NewGroupContext(ctx context.Context) TaskGroup {
 }
 
 func (p *pool) dispatch(incomingTasks []any) {
+	p.dispatcherRunning.Lock()
+	defer p.dispatcherRunning.Unlock()
+
 	// Submit tasks
 	for _, task := range incomingTasks {
 		p.dispatchTask(task)
 	}
 }
 
 func (p *pool) dispatchTask(task any) {
-	workerCount := int(p.workerCount.Load())
-
 	// Attempt to submit task without blocking
 	select {
 	case p.tasks <- task:
@@ -214,19 +218,13 @@ func (p *pool) dispatchTask(task any) {
 		// 1. There are no idle workers (all spawned workers are processing a task)
 		// 2. There are no workers in the pool
 		// In either case, we should launch a new worker as long as the number of workers is less than the size of the task queue.
-		if workerCount < p.tasksLen {
-			// Launch a new worker
-			p.startWorker()
-		}
+		p.startWorker(p.tasksLen)
 		return
 	default:
 	}
 
 	// Task queue is full, launch a new worker if the number of workers is less than the maximum concurrency
-	if workerCount < p.maxConcurrency {
-		// Launch a new worker
-		p.startWorker()
-	}
+	p.startWorker(p.maxConcurrency)
 
 	// Block until task is submitted
 	select {
@@ -238,15 +236,41 @@ func (p *pool) dispatchTask(task any) {
 	}
 }
 
-func (p *pool) startWorker() {
+func (p *pool) startWorker(limit int) {
+	if p.workerCount.Load() >= int64(limit) {
+		return
+	}
 	p.workerWaitGroup.Add(1)
-	p.workerCount.Add(1)
-	go p.worker()
+	workerNumber := p.workerCount.Add(1)
+	// Guarantee at least PERSISTENT_WORKER_COUNT workers are always running during dispatch to prevent deadlocks
+	canExitDuringDispatch := workerNumber > PERSISTENT_WORKER_COUNT
+	go p.worker(canExitDuringDispatch)
 }
 
-func (p *pool) worker() {
-	defer func() {
+func (p *pool) workerCanExit(canExitDuringDispatch bool) bool {
+	if canExitDuringDispatch {
 		p.workerCount.Add(-1)
+		return true
+	}
+
+	// Check if the dispatcher is running
+	if !p.dispatcherRunning.TryLock() {
+		// Dispatcher is running, cannot exit yet
+		return false
+	}
+	if len(p.tasks) > 0 {
+		// There are tasks in the queue, cannot exit yet
+		p.dispatcherRunning.Unlock()
+		return false
+	}
+	p.workerCount.Add(-1)
+	p.dispatcherRunning.Unlock()
+
+	return true
+}
+
+func (p *pool) worker(canExitDuringDispatch bool) {
+	defer func() {
 		p.workerWaitGroup.Done()
 	}()
 
@@ -255,17 +279,20 @@ func (p *pool) worker() {
 		select {
 		case <-p.ctx.Done():
 			// Context cancelled, exit
+			p.workerCount.Add(-1)
 			return
 		default:
 		}
 
 		select {
 		case <-p.ctx.Done():
 			// Context cancelled, exit
+			p.workerCount.Add(-1)
 			return
 		case task, ok := <-p.tasks:
 			if !ok || task == nil {
 				// Channel closed or worker killed, exit
+				p.workerCount.Add(-1)
 				return
 			}
 
@@ -276,8 +303,13 @@ func (p *pool) worker() {
 			p.updateMetrics(err)
 
 		default:
-			// No tasks left, exit
-			return
+			// No tasks left
+
+			// Check if the worker can exit
+			if p.workerCanExit(canExitDuringDispatch) {
+				return
+			}
+			continue
 		}
 	}
 }
