Skip to content
This repository has been archived by the owner on Jul 21, 2021. It is now read-only.

Commit

Permalink
Update DNS lookup behavior
Browse files Browse the repository at this point in the history
  • Loading branch information
Yu Xie committed Mar 23, 2018
1 parent c4fab1a commit dab8413
Show file tree
Hide file tree
Showing 2 changed files with 106 additions and 16 deletions.
80 changes: 64 additions & 16 deletions zk/dnshostprovider.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,24 +4,68 @@ import (
"fmt"
"net"
"sync"
"time"
)

// lookupInterval is the interval of retrying DNS lookup for unresolved hosts
const lookupInterval = time.Minute * 3

// DNSHostProvider is the default HostProvider. It currently matches
// the Java StaticHostProvider, resolving hosts from DNS once during
// the call to Init. It could be easily extended to re-query DNS
// periodically or if there is trouble connecting.
type DNSHostProvider struct {
mu sync.Mutex // Protects everything, so we can add asynchronous updates later.
servers []string
curr int
last int
lookupHost func(string) ([]string, error) // Override of net.LookupHost, for testing.
sleep func(time.Duration) // Override of time.Sleep, for testing.

mu sync.Mutex // Protects everything, so we can add asynchronous updates later.
servers []string
unresolvedServers map[string]struct{}
curr int
last int
lookupHost func(string) ([]string, error) // Override of net.LookupHost, for testing.
}

// Init is called first, with the servers specified in the connection
// string. It uses DNS to look up addresses for each server, then
// shuffles them all together.
func (hp *DNSHostProvider) Init(servers []string) error {
if hp.sleep == nil {
hp.sleep = time.Sleep
}
hp.servers = make([]string, 0, len(servers))
hp.unresolvedServers = make(map[string]struct{}, len(servers))
for _, server := range servers {
hp.unresolvedServers[server] = struct{}{}
}

done, err := hp.lookupUnresolvedServers()
if err != nil {
return err
}

// as long as any host resolved successfully, consider the connection as success
// but start a lookup loop until all servers are resolved and added to servers list
if !done {
go hp.lookupLoop()
}

return nil
}

// lookupLoop calls lookupUnresolvedServers in an infinite loop until all hosts are resolved
// should be called in a separate goroutine
func (hp *DNSHostProvider) lookupLoop() {
for {
if done, _ := hp.lookupUnresolvedServers(); done {
break
}
hp.sleep(lookupInterval)
}
}

// lookupUnresolvedServers DNS lookup the hosts that not successfully resolved yet
// and add them to servers list
func (hp *DNSHostProvider) lookupUnresolvedServers() (bool, error) {
hp.mu.Lock()
defer hp.mu.Unlock()

Expand All @@ -30,33 +74,37 @@ func (hp *DNSHostProvider) Init(servers []string) error {
lookupHost = net.LookupHost
}

found := []string{}
for _, server := range servers {
if len(hp.unresolvedServers) == 0 {
return true, nil
}

found := make([]string, 0, len(hp.unresolvedServers))
for server := range hp.unresolvedServers {
host, port, err := net.SplitHostPort(server)
if err != nil {
return err
return false, err
}
addrs, err := lookupHost(host)
if err != nil {
return err
continue
}
delete(hp.unresolvedServers, server)
for _, addr := range addrs {
found = append(found, net.JoinHostPort(addr, port))
}
}

if len(found) == 0 {
return fmt.Errorf("No hosts found for addresses %q", servers)
}

// Randomize the order of the servers to avoid creating hotspots
stringShuffle(found)

hp.servers = found
hp.servers = append(hp.servers, found...)
hp.curr = -1
hp.last = -1

return nil
if len(hp.servers) == 0 {
return true, fmt.Errorf("No hosts found for addresses %q", hp.servers)
}

return false, nil
}

// Len returns the number of servers available
Expand Down
42 changes: 42 additions & 0 deletions zk/dnshostprovider_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package zk

import (
"errors"
"fmt"
"log"
"testing"
Expand Down Expand Up @@ -165,6 +166,47 @@ func TestDNSHostProviderReconnect(t *testing.T) {
}
}

// TestDNSHostOneHostDead tests whether
func TestDNSHostOneHostDead(t *testing.T) {
hp := &DNSHostProvider{lookupHost: func(host string) ([]string, error) {
if host == "foo.failure.com" {
return nil, errors.New("Fails to ns lookup")
}
return []string{"192.0.2.1", "192.0.2.2"}, nil
}, sleep: func(_ time.Duration) {}}

if err := hp.Init([]string{"foo.failure.com:12345", "foo.success.com:12345"}); err != nil {
t.Fatal(err)
}

hp.mu.Lock()
if len(hp.servers) != 2 {
t.Fatal("Only servers that resolved by lookupHost should be in servers list")
}

// update lookupHost to mock a successful lookup
hp.lookupHost = func(host string) ([]string, error) {
if host == "foo.failure.com" {
return []string{"192.0.2.3"}, nil
}
return []string{"192.0.2.1", "192.0.2.2"}, nil
}
hp.mu.Unlock()

// Starts a 30s retry loop to wait servers list to be updated
startRetryLoop := time.Now()
for {
hp.mu.Lock()
if len(hp.servers) == 3 {
break
}
hp.mu.Unlock()
if time.Since(startRetryLoop) > time.Second * 30 {
t.Fatal("Servers get back online should be added to the servers list")
}
}
}

// TestDNSHostProviderRetryStart tests the `retryStart` functionality
// of DNSHostProvider.
// It's also probably the clearest visual explanation of exactly how
Expand Down

0 comments on commit dab8413

Please sign in to comment.