Commit c018d93

feat(pool): add support for dynamically adjusting max concurrency
Parent commit: e2470f0

3 files changed: +193 additions, −3 deletions

pool.go

Lines changed: 55 additions & 3 deletions
@@ -20,9 +20,10 @@ const (
 )

 var (
-	ErrQueueFull   = errors.New("queue is full")
-	ErrQueueEmpty  = errors.New("queue is empty")
-	ErrPoolStopped = errors.New("pool stopped")
+	ErrQueueFull             = errors.New("queue is full")
+	ErrQueueEmpty            = errors.New("queue is empty")
+	ErrPoolStopped           = errors.New("pool stopped")
+	ErrMaxConcurrencyReached = errors.New("max concurrency reached")

 	poolStoppedFuture = func() Task {
 		future, resolve := future.NewFuture(context.Background())
@@ -73,6 +74,11 @@ type basePool interface {

 	// Returns true if the pool has been stopped or its context has been cancelled.
 	Stopped() bool
+
+	// Resizes the pool by changing the maximum concurrency (number of workers) of the pool.
+	// The new max concurrency must be greater than 0.
+	// If the new max concurrency is less than the current number of running workers, the pool will continue to run with the new max concurrency.
+	Resize(maxConcurrency int)
 }

 // Represents a pool of goroutines that can execute tasks concurrently.
@@ -125,9 +131,46 @@ func (p *pool) Stopped() bool {
 }

 func (p *pool) MaxConcurrency() int {
+	p.mutex.Lock()
+	defer p.mutex.Unlock()
+
 	return p.maxConcurrency
 }

+func (p *pool) Resize(maxConcurrency int) {
+	if maxConcurrency <= 0 {
+		panic(errors.New("maxConcurrency must be greater than 0"))
+	}
+
+	p.mutex.Lock()
+	defer p.mutex.Unlock()
+
+	delta := maxConcurrency - p.maxConcurrency
+
+	p.maxConcurrency = maxConcurrency
+
+	if delta > 0 {
+		// Increase the number of workers by delta if there are tasks in the queue
+		for i := 0; i < delta; i++ {
+			if poppedTask, _ := p.tasks.Read(); poppedTask != nil {
+
+				p.workerCount.Add(1)
+
+				if p.parent == nil {
+					p.workerWaitGroup.Add(1)
+					// Launch a new worker
+					go p.worker(poppedTask)
+				} else {
+					// Submit task to the parent pool
+					p.subpoolSubmit(poppedTask)
+				}
+			} else {
+				return
+			}
+		}
+	}
+}
+
 func (p *pool) QueueSize() int {
 	return p.queueSize
 }
@@ -338,6 +381,15 @@ func (p *pool) readTask() (task any, err error) {
 		return
 	}

+	if p.maxConcurrency > 0 && int(p.workerCount.Load()) > p.maxConcurrency {
+		// Max concurrency reached, kill the worker
+		p.workerCount.Add(-1)
+		p.mutex.Unlock()
+
+		err = ErrMaxConcurrencyReached
+		return
+	}
+
 	task, _ = p.tasks.Read()

 	p.mutex.Unlock()
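
Taken together, the pool.go changes make MaxConcurrency thread-safe, add Resize (which also starts extra workers for tasks that are already queued), and teach readTask to retire surplus workers after a shrink. A minimal usage sketch follows, written as if it sat in the same package as the tests below; the package name, the function name and the timings are assumptions for illustration, not part of this commit.

package pond // assumed package name, matching pool.go and the tests

import (
	"fmt"
	"time"
)

// exampleResize sketches how a caller might grow and shrink a pool at runtime.
// Only identifiers exercised in this commit (NewPool, WithQueueSize, Submit,
// Resize, MaxConcurrency, Stop) are used; the workload is illustrative.
func exampleResize() {
	pool := NewPool(1, WithQueueSize(10)) // start with a single worker

	// Queue more work than one worker can drain quickly.
	for i := 0; i < 10; i++ {
		pool.Submit(func() {
			time.Sleep(50 * time.Millisecond)
		})
	}

	// Grow: up to 4 workers may now run; Resize also starts new workers
	// for tasks that are already waiting in the queue.
	pool.Resize(4)

	// Shrink: the limit drops immediately, and surplus workers exit once
	// they finish their current task (via the readTask check above).
	pool.Resize(2)
	fmt.Println(pool.MaxConcurrency()) // prints 2

	// Stop the pool and wait for it to finish.
	pool.Stop().Wait()
}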

pool_test.go

Lines changed: 65 additions & 0 deletions
@@ -274,3 +274,68 @@ func TestPoolWithQueueSizeAndNonBlocking(t *testing.T) {

 	pool.Stop().Wait()
 }
+
+func TestPoolResize(t *testing.T) {
+
+	pool := NewPool(1, WithQueueSize(10))
+
+	assert.Equal(t, 1, pool.MaxConcurrency())
+
+	taskStarted := make(chan struct{}, 10)
+	taskWait := make(chan struct{}, 10)
+
+	// Submit 10 tasks
+	for i := 0; i < 10; i++ {
+		pool.Submit(func() {
+			<-taskStarted
+			<-taskWait
+		})
+	}
+
+	// Unblock 3 tasks
+	for i := 0; i < 3; i++ {
+		taskStarted <- struct{}{}
+	}
+
+	// Verify only 1 task is running and 9 are waiting
+	time.Sleep(10 * time.Millisecond)
+	assert.Equal(t, uint64(9), pool.WaitingTasks())
+	assert.Equal(t, int64(1), pool.RunningWorkers())
+
+	// Increase max concurrency to 3
+	pool.Resize(3)
+	assert.Equal(t, 3, pool.MaxConcurrency())
+
+	// Unblock 3 more tasks
+	for i := 0; i < 3; i++ {
+		taskStarted <- struct{}{}
+	}
+
+	// Verify 3 tasks are running and 7 are waiting
+	time.Sleep(10 * time.Millisecond)
+	assert.Equal(t, uint64(7), pool.WaitingTasks())
+	assert.Equal(t, int64(3), pool.RunningWorkers())
+
+	// Decrease max concurrency to 2
+	pool.Resize(2)
+	assert.Equal(t, 2, pool.MaxConcurrency())
+
+	// Complete the 3 running tasks
+	for i := 0; i < 3; i++ {
+		taskWait <- struct{}{}
+	}
+
+	// Unblock all remaining tasks
+	for i := 0; i < 4; i++ {
+		taskStarted <- struct{}{}
+	}
+
+	// Ensure 2 tasks are running and 5 are waiting
+	time.Sleep(10 * time.Millisecond)
+	assert.Equal(t, uint64(5), pool.WaitingTasks())
+	assert.Equal(t, int64(2), pool.RunningWorkers())
+
+	close(taskWait)
+
+	pool.Stop().Wait()
+}
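
The shrink path is worth spelling out: Resize never interrupts a running worker, it only lowers p.maxConcurrency, and each worker that comes back to readTask for another task checks the worker count against the new limit and exits with ErrMaxConcurrencyReached if it is surplus. Below is a simplified, standalone model of that counter logic; poolModel and its members are hypothetical names for illustration and do not reflect the library's actual types.

package main

import (
	"fmt"
	"sync"
	"sync/atomic"
)

// poolModel is a hypothetical, stripped-down model of the shrink behaviour:
// a mutex-protected limit plus an atomic worker count, checked by each worker
// before it picks up another task.
type poolModel struct {
	mu          sync.Mutex
	max         int
	workerCount atomic.Int64
	tasks       chan func()
	wg          sync.WaitGroup
}

func (p *poolModel) resize(max int) {
	p.mu.Lock()
	p.max = max
	p.mu.Unlock()
}

func (p *poolModel) worker() {
	defer p.wg.Done()
	for {
		// Check the limit before taking the next task, decrementing the
		// count under the same lock so workers cannot over-shrink.
		p.mu.Lock()
		if int(p.workerCount.Load()) > p.max {
			p.workerCount.Add(-1)
			p.mu.Unlock()
			return // surplus worker retires after finishing its current task
		}
		p.mu.Unlock()

		task, ok := <-p.tasks
		if !ok {
			p.workerCount.Add(-1)
			return // queue closed, worker exits normally
		}
		task()
	}
}

func main() {
	p := &poolModel{max: 3, tasks: make(chan func(), 16)}

	// Start three workers, matching the initial limit.
	for i := 0; i < 3; i++ {
		p.workerCount.Add(1)
		p.wg.Add(1)
		go p.worker()
	}

	for i := 0; i < 10; i++ {
		p.tasks <- func() {}
	}

	// Lower the limit: no worker is interrupted, but surplus workers exit
	// the next time they ask for a task.
	p.resize(1)

	close(p.tasks)
	p.wg.Wait()
	fmt.Println("workers left:", p.workerCount.Load()) // 0: each worker decremented before returning
}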

subpool_test.go

Lines changed: 73 additions & 0 deletions
@@ -334,3 +334,76 @@ func TestSubpoolWithQueueSizeOverride(t *testing.T) {
 	subpool.StopAndWait()
 	pool.StopAndWait()
 }
+
+func TestSubpoolResize(t *testing.T) {
+
+	parentPool := NewPool(10, WithQueueSize(10))
+
+	pool := parentPool.NewSubpool(1)
+
+	assert.Equal(t, 1, pool.MaxConcurrency())
+	assert.Equal(t, 10, parentPool.MaxConcurrency())
+
+	taskStarted := make(chan struct{}, 10)
+	taskWait := make(chan struct{}, 10)
+
+	// Submit 10 tasks
+	for i := 0; i < 10; i++ {
+		pool.Submit(func() {
+			<-taskStarted
+			<-taskWait
+		})
+	}
+
+	// Unblock 3 tasks
+	for i := 0; i < 3; i++ {
+		taskStarted <- struct{}{}
+	}
+
+	// Verify only 1 task is running and 9 are waiting
+	time.Sleep(10 * time.Millisecond)
+	assert.Equal(t, uint64(9), pool.WaitingTasks())
+	assert.Equal(t, int64(1), pool.RunningWorkers())
+	assert.Equal(t, int64(1), parentPool.RunningWorkers())
+
+	// Increase max concurrency to 3
+	pool.Resize(3)
+	assert.Equal(t, 3, pool.MaxConcurrency())
+	assert.Equal(t, 10, parentPool.MaxConcurrency())
+
+	// Unblock 3 more tasks
+	for i := 0; i < 3; i++ {
+		taskStarted <- struct{}{}
+	}
+
+	// Verify 3 tasks are running and 7 are waiting
+	time.Sleep(10 * time.Millisecond)
+	assert.Equal(t, uint64(7), pool.WaitingTasks())
+	assert.Equal(t, int64(3), pool.RunningWorkers())
+	assert.Equal(t, int64(3), parentPool.RunningWorkers())
+
+	// Decrease max concurrency to 2
+	pool.Resize(2)
+	assert.Equal(t, 2, pool.MaxConcurrency())
+	assert.Equal(t, 10, parentPool.MaxConcurrency())
+
+	// Complete the 3 running tasks
+	for i := 0; i < 3; i++ {
+		taskWait <- struct{}{}
+	}
+
+	// Unblock all remaining tasks
+	for i := 0; i < 4; i++ {
+		taskStarted <- struct{}{}
+	}
+
+	// Ensure 2 tasks are running and 5 are waiting
+	time.Sleep(10 * time.Millisecond)
+	assert.Equal(t, uint64(5), pool.WaitingTasks())
+	assert.Equal(t, int64(2), pool.RunningWorkers())
+	assert.Equal(t, int64(2), parentPool.RunningWorkers())
+
+	close(taskWait)
+
+	pool.Stop().Wait()
+}
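
TestSubpoolResize above verifies that resizing a subpool leaves the parent pool's maximum untouched while the subpool's workers continue to count against parent capacity. A short hypothetical sketch of that relationship, again with the package name and function name assumed rather than taken from the commit:

package pond // assumed package name, matching the test files

import "fmt"

// exampleSubpoolResize shows that Resize on a subpool changes only the
// subpool's own limit; the parent keeps its own maximum. Only identifiers
// used in this commit's diffs and tests appear here.
func exampleSubpoolResize() {
	parent := NewPool(10, WithQueueSize(10)) // parent allows up to 10 workers
	sub := parent.NewSubpool(1)              // subpool starts capped at 1 worker

	for i := 0; i < 10; i++ {
		sub.Submit(func() {})
	}

	sub.Resize(3) // raise only the subpool's cap

	fmt.Println(sub.MaxConcurrency())    // 3
	fmt.Println(parent.MaxConcurrency()) // still 10

	sub.StopAndWait()
	parent.StopAndWait()
}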
