Skip to content

Commit 4c41dd8

Browse files
authored
Add unicast pattern (#7)
1 parent 903e85f commit 4c41dd8

File tree

6 files changed

+426
-69
lines changed

6 files changed

+426
-69
lines changed

.github/workflows/ci.yml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ jobs:
1313
- uses: actions/checkout@v4
1414
- uses: actions/setup-go@v5
1515
with:
16-
go-version: '1.21.6'
16+
go-version: '1.23.4'
1717
cache: false
1818
- name: Run linter
1919
working-directory: ./
@@ -26,7 +26,7 @@ jobs:
2626
- uses: actions/checkout@v4
2727
- uses: actions/setup-go@v5
2828
with:
29-
go-version: '1.21.6'
29+
go-version: '1.23.4'
3030
cache: false
3131
- name: Run tests
3232
working-directory: ./
@@ -39,7 +39,7 @@ jobs:
3939
- uses: actions/checkout@v4
4040
- uses: actions/setup-go@v5
4141
with:
42-
go-version: '1.21.6'
42+
go-version: '1.23.4'
4343
cache: false
4444
- name: Run code vetter
4545
working-directory: ./

README.md

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,44 @@ func main() {
5454
// all handler contexts will inherit from this context
5555
rootCtx := context.Background()
5656

57+
// send query to one grpc server
58+
dispatcher.Unicast(rootCtx, "node-name", func(ctx context.Context, conn *grpc.ClientConn) {
59+
// init grpc client
60+
client := examplepb.NewExampleServiceClient(conn)
61+
62+
// execute grpc request
63+
resp, err := client.Echo(ctx, &examplepb.EchoRequest{Message: "hello"})
64+
if err != nil {
65+
// do something with error
66+
fmt.Println(err)
67+
return
68+
}
69+
70+
// do something with response
71+
fmt.Println(resp)
72+
})
73+
74+
// send query to one grpc server and future servers at same node
75+
unicastSub, err := dispatcher.UnicastSubscribe(rootCtx, "node-name", func(ctx context.Context, conn *grpc.ClientConn) error {
76+
// init grpc client
77+
client := examplepb.NewExampleServiceClient(conn)
78+
79+
// execute grpc request
80+
resp, err := client.Echo(ctx, &examplepb.EchoRequest{Message: "hello"})
81+
if err != nil {
82+
// do something with error
83+
fmt.Println(err)
84+
return
85+
}
86+
87+
// do something with response
88+
fmt.Println(resp)
89+
})
90+
if err != nil {
91+
panic(err)
92+
}
93+
defer unicastSub.Unsubscribe()
94+
5795
// send query to all current grpc servers
5896
dispatcher.Fanout(rootCtx, func(ctx context.Context, conn *grpc.ClientConn) {
5997
// init grpc client
@@ -72,7 +110,7 @@ func main() {
72110
})
73111

74112
// send query to all current and future grpc servers
75-
sub, err := dispatcher.FanoutSubscribe(rootCtx, func(ctx context.Context, conn *grpc.ClientConn) error {
113+
fanoutSub, err := dispatcher.FanoutSubscribe(rootCtx, func(ctx context.Context, conn *grpc.ClientConn) error {
76114
// init grpc client
77115
client := examplepb.NewExampleServiceClient(conn)
78116

@@ -90,7 +128,7 @@ func main() {
90128
if err != nil {
91129
panic(err)
92130
}
93-
defer sub.Unsubscribe()
131+
defer fanoutSub.Unsubscribe()
94132
}
95133
```
96134

dispatcher.go

Lines changed: 134 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -53,18 +53,24 @@ func init() {
5353
balancer.Register(builder)
5454
}
5555

56+
// Represents gRPC server internally
57+
type server struct {
58+
ip string
59+
nodeName string
60+
}
61+
5662
// Represents the callback argument to the dispatch methods
5763
type DispatchHandler func(ctx context.Context, conn *grpc.ClientConn)
5864

5965
// Represents interest in pod ips that are part of a Kubernetes service
6066
type Subscription struct {
61-
ipCh chan string
62-
cleanup func()
67+
serverCh chan server
68+
cleanup func()
6369
}
6470

6571
// Ends subscription
6672
func (sub *Subscription) Unsubscribe() {
67-
close(sub.ipCh)
73+
close(sub.serverCh)
6874
sub.cleanup()
6975
}
7076

@@ -78,22 +84,108 @@ type Dispatcher struct {
7884
informerReg cache.ResourceEventHandlerRegistration
7985
resolver *manual.Resolver
8086
conn *grpc.ClientConn
81-
ips mapset.Set[string]
87+
servers mapset.Set[server]
8288
mu sync.Mutex
8389
eventbus eventbus.Bus
8490
stopCh chan struct{}
8591
}
8692

93+
// Sends query to matching server at query-time
94+
func (d *Dispatcher) Unicast(ctx context.Context, nodeName string, fn DispatchHandler) {
95+
d.mu.Lock()
96+
currentServers := d.servers.ToSlice()
97+
d.mu.Unlock()
98+
99+
// Get ip for a server at `nodeName`
100+
var ip string
101+
for _, server := range currentServers {
102+
if server.nodeName == nodeName {
103+
ip = server.ip
104+
}
105+
}
106+
107+
// Exit if server not found
108+
if ip == "" {
109+
return
110+
}
111+
112+
doneCh := make(chan struct{})
113+
go func() {
114+
defer close(doneCh)
115+
connCtx := context.WithValue(ctx, dispatcherAddrCtxKey, fmt.Sprintf("%s:%s", ip, d.connectArgs.Port))
116+
fn(connCtx, d.conn)
117+
}()
118+
119+
// wait for func or ctx to finish whichever comes first
120+
select {
121+
case <-ctx.Done():
122+
case <-doneCh:
123+
}
124+
}
125+
126+
// Sends query to matching server at query-time and all subsequent servers when
127+
// they become available until Unsubscribe() is called
128+
func (d *Dispatcher) UnicastSubscribe(ctx context.Context, nodeName string, fn DispatchHandler) (*Subscription, error) {
129+
serverCh := make(chan server)
130+
131+
// server handler
132+
handleNewServers := func(newServers []server) {
133+
for _, server := range newServers {
134+
if server.nodeName == nodeName {
135+
serverCh <- server
136+
}
137+
}
138+
}
139+
140+
// worker
141+
go func() {
142+
for {
143+
select {
144+
case <-ctx.Done():
145+
return
146+
case server, ok := <-serverCh:
147+
if !ok {
148+
// unsubscribe was called
149+
return
150+
}
151+
152+
// execute dispatch handler in goroutine
153+
connCtx := context.WithValue(ctx, dispatcherAddrCtxKey, fmt.Sprintf("%s:%s", server.ip, d.connectArgs.Port))
154+
go fn(connCtx, d.conn)
155+
}
156+
}
157+
}()
158+
159+
// get current ips and subscribe to new ones in a lock
160+
d.mu.Lock()
161+
currentServers := d.servers.ToSlice()
162+
err := d.eventbus.SubscribeAsync("add:servers", handleNewServers, false)
163+
if err != nil {
164+
d.mu.Unlock()
165+
return nil, err
166+
}
167+
d.mu.Unlock()
168+
169+
handleNewServers(currentServers)
170+
171+
return &Subscription{
172+
serverCh: serverCh,
173+
cleanup: func() {
174+
d.eventbus.Unsubscribe("add:servers", handleNewServers)
175+
},
176+
}, nil
177+
}
178+
87179
// Sends queries to all available ips at query-time
88180
func (d *Dispatcher) Fanout(ctx context.Context, fn DispatchHandler) {
89181
var wg sync.WaitGroup
90182

91183
d.mu.Lock()
92-
ips := d.ips.ToSlice()
184+
servers := d.servers.ToSlice()
93185
d.mu.Unlock()
94186

95-
for _, ip := range ips {
96-
connCtx := context.WithValue(ctx, dispatcherAddrCtxKey, fmt.Sprintf("%s:%s", ip, d.connectArgs.Port))
187+
for _, server := range servers {
188+
connCtx := context.WithValue(ctx, dispatcherAddrCtxKey, fmt.Sprintf("%s:%s", server.ip, d.connectArgs.Port))
97189
wg.Add(1)
98190
go func(lclCtx context.Context) {
99191
defer wg.Done()
@@ -117,12 +209,12 @@ func (d *Dispatcher) Fanout(ctx context.Context, fn DispatchHandler) {
117209
// Sends queries to all available ips at query-time and all subsequent ips when
118210
// they become available until Unsubscribe() is called
119211
func (d *Dispatcher) FanoutSubscribe(ctx context.Context, fn DispatchHandler) (*Subscription, error) {
120-
ipCh := make(chan string)
212+
serverCh := make(chan server)
121213

122-
// ip handler
123-
handleNewIps := func(newIps []string) {
124-
for _, ip := range newIps {
125-
ipCh <- ip
214+
// server handler
215+
handleNewServers := func(newServers []server) {
216+
for _, server := range newServers {
217+
serverCh <- server
126218
}
127219
}
128220

@@ -132,35 +224,35 @@ func (d *Dispatcher) FanoutSubscribe(ctx context.Context, fn DispatchHandler) (*
132224
select {
133225
case <-ctx.Done():
134226
return
135-
case ip, ok := <-ipCh:
227+
case server, ok := <-serverCh:
136228
if !ok {
137229
// unsubscribe was called
138230
return
139231
}
140232

141233
// execute dispatch handler in goroutine
142-
connCtx := context.WithValue(ctx, dispatcherAddrCtxKey, fmt.Sprintf("%s:%s", ip, d.connectArgs.Port))
234+
connCtx := context.WithValue(ctx, dispatcherAddrCtxKey, fmt.Sprintf("%s:%s", server.ip, d.connectArgs.Port))
143235
go fn(connCtx, d.conn)
144236
}
145237
}
146238
}()
147239

148240
// get current ips and subscribe to new ones in a lock
149241
d.mu.Lock()
150-
currentIps := d.ips.ToSlice()
151-
err := d.eventbus.SubscribeAsync("add:addrs", handleNewIps, false)
242+
currentServers := d.servers.ToSlice()
243+
err := d.eventbus.SubscribeAsync("add:servers", handleNewServers, false)
152244
if err != nil {
153245
d.mu.Unlock()
154246
return nil, err
155247
}
156248
d.mu.Unlock()
157249

158-
handleNewIps(currentIps)
250+
handleNewServers(currentServers)
159251

160252
return &Subscription{
161-
ipCh: ipCh,
253+
serverCh: serverCh,
162254
cleanup: func() {
163-
d.eventbus.Unsubscribe("add:addrs", handleNewIps)
255+
d.eventbus.Unsubscribe("add:servers", handleNewServers)
164256
},
165257
}, nil
166258
}
@@ -210,14 +302,14 @@ func (d *Dispatcher) Shutdown() error {
210302

211303
// Handle add
212304
func (d *Dispatcher) handleAddEndpointSlice(es *discoveryv1.EndpointSlice) {
213-
newIps := getIpsFromEndpointSlice(es)
214-
d.updateState(newIps, nil)
305+
newServers := getServersFromEndpointSlice(es)
306+
d.updateState(newServers, nil)
215307
}
216308

217309
// Handle updates
218310
func (d *Dispatcher) handleUpdateEndpointSlice(esOld *discoveryv1.EndpointSlice, esNew *discoveryv1.EndpointSlice) {
219-
oldIps := mapset.NewSet(getIpsFromEndpointSlice(esOld)...)
220-
newIps := mapset.NewSet(getIpsFromEndpointSlice(esNew)...)
311+
oldIps := mapset.NewSet(getServersFromEndpointSlice(esOld)...)
312+
newIps := mapset.NewSet(getServersFromEndpointSlice(esNew)...)
221313

222314
toDelete := oldIps.Difference(newIps)
223315
toAdd := newIps.Difference(oldIps)
@@ -226,16 +318,16 @@ func (d *Dispatcher) handleUpdateEndpointSlice(esOld *discoveryv1.EndpointSlice,
226318
}
227319

228320
// Adds and deletes ips, updates clientconn state, publishes change to eventbus
229-
func (d *Dispatcher) updateState(toAdd []string, toDelete []string) {
321+
func (d *Dispatcher) updateState(toAdd []server, toDelete []server) {
230322
d.mu.Lock()
231323

232324
// update local state
233325
if len(toDelete) > 0 {
234-
d.ips.RemoveAll(toDelete...)
326+
d.servers.RemoveAll(toDelete...)
235327
}
236328

237329
if len(toAdd) > 0 {
238-
d.ips.Append(toAdd...)
330+
d.servers.Append(toAdd...)
239331
}
240332

241333
// exit if no changes
@@ -245,11 +337,11 @@ func (d *Dispatcher) updateState(toAdd []string, toDelete []string) {
245337
}
246338

247339
// update clientconn state
248-
ips := d.ips.ToSlice()
340+
servers := d.servers.ToSlice()
249341

250-
addrs := make([]resolver.Address, len(ips))
251-
for i, ip := range ips {
252-
addrs[i] = resolver.Address{Addr: fmt.Sprintf("%s:%s", ip, d.connectArgs.Port)}
342+
addrs := make([]resolver.Address, len(servers))
343+
for i, server := range servers {
344+
addrs[i] = resolver.Address{Addr: fmt.Sprintf("%s:%s", server.ip, d.connectArgs.Port)}
253345
}
254346

255347
d.resolver.UpdateState(resolver.State{Addresses: addrs})
@@ -258,7 +350,7 @@ func (d *Dispatcher) updateState(toAdd []string, toDelete []string) {
258350

259351
// publish change
260352
if len(toAdd) > 0 {
261-
d.eventbus.Publish("add:addrs", toAdd)
353+
d.eventbus.Publish("add:servers", toAdd)
262354
}
263355
}
264356

@@ -331,7 +423,7 @@ func NewDispatcher(connectUrl string, options ...DispatcherOption) (*Dispatcher,
331423
informer: informer,
332424
resolver: resolver,
333425
conn: conn,
334-
ips: mapset.NewSet[string](),
426+
servers: mapset.NewSet[server](),
335427
eventbus: eventbus.New(),
336428
}, nil
337429
}
@@ -379,12 +471,18 @@ func parseConnectUrl(connectUrl string) (*connectArgs, error) {
379471
}, nil
380472
}
381473

382-
func getIpsFromEndpointSlice(es *discoveryv1.EndpointSlice) []string {
383-
var ips []string
474+
func getServersFromEndpointSlice(es *discoveryv1.EndpointSlice) []server {
475+
var servers []server
384476
for _, endpoint := range es.Endpoints {
385477
if *endpoint.Conditions.Serving {
386-
ips = append(ips, endpoint.Addresses...)
478+
for _, addr := range endpoint.Addresses {
479+
s := server{
480+
nodeName: *endpoint.NodeName,
481+
ip: addr,
482+
}
483+
servers = append(servers, s)
484+
}
387485
}
388486
}
389-
return ips
487+
return servers
390488
}

0 commit comments

Comments
 (0)