From fb52b7d0d23e339478faf3677e2cb5854c291c07 Mon Sep 17 00:00:00 2001 From: Yu Xie Date: Thu, 22 Mar 2018 18:43:13 -0700 Subject: [PATCH] Update DNS lookup behavior --- zk/dnshostprovider.go | 92 +++++++++++++++++++++++++++++--------- zk/dnshostprovider_test.go | 45 +++++++++++++++++++ 2 files changed, 115 insertions(+), 22 deletions(-) diff --git a/zk/dnshostprovider.go b/zk/dnshostprovider.go index f4bba8d0..d5828542 100644 --- a/zk/dnshostprovider.go +++ b/zk/dnshostprovider.go @@ -4,59 +4,107 @@ import ( "fmt" "net" "sync" + "time" ) +// lookupInterval is the interval of retrying DNS lookup for unresolved hosts +const lookupInterval = time.Minute * 3 + // DNSHostProvider is the default HostProvider. It currently matches // the Java StaticHostProvider, resolving hosts from DNS once during // the call to Init. It could be easily extended to re-query DNS // periodically or if there is trouble connecting. type DNSHostProvider struct { - mu sync.Mutex // Protects everything, so we can add asynchronous updates later. - servers []string - curr int - last int - lookupHost func(string) ([]string, error) // Override of net.LookupHost, for testing. + // fields above mu are not thread safe + unresolvedServers map[string]struct{} + sleep func(time.Duration) // Override of time.Sleep, for testing. + lookupHost func(string) ([]string, error) // Override of net.LookupHost, for testing. + + mu sync.Mutex // Protects everything below, so we can add asynchronous updates later. + servers []string + curr int + last int } // Init is called first, with the servers specified in the connection // string. It uses DNS to look up addresses for each server, then // shuffles them all together. func (hp *DNSHostProvider) Init(servers []string) error { - hp.mu.Lock() - defer hp.mu.Unlock() - - lookupHost := hp.lookupHost - if lookupHost == nil { - lookupHost = net.LookupHost + if hp.sleep == nil { + hp.sleep = time.Sleep + } + if hp.lookupHost == nil { + hp.lookupHost = net.LookupHost } - found := []string{} + hp.servers = make([]string, 0, len(servers)) + hp.unresolvedServers = make(map[string]struct{}, len(servers)) for _, server := range servers { + hp.unresolvedServers[server] = struct{}{} + } + + done, err := hp.lookupUnresolvedServers() + if err != nil { + return err + } + + // as long as any host resolved successfully, consider the connection as success + // but start a lookup loop until all servers are resolved and added to servers list + if !done { + go hp.lookupLoop() + } + + return nil +} + +// lookupLoop calls lookupUnresolvedServers in an infinite loop until all hosts are resolved +// should be called in a separate goroutine +func (hp *DNSHostProvider) lookupLoop() { + for { + if done, _ := hp.lookupUnresolvedServers(); done { + break + } + hp.sleep(lookupInterval) + } +} + +// lookupUnresolvedServers DNS lookup the hosts that not successfully resolved yet +// and add them to servers list +func (hp *DNSHostProvider) lookupUnresolvedServers() (bool, error) { + if len(hp.unresolvedServers) == 0 { + return true, nil + } + + found := make([]string, 0, len(hp.unresolvedServers)) + for server := range hp.unresolvedServers { host, port, err := net.SplitHostPort(server) if err != nil { - return err + return false, err } - addrs, err := lookupHost(host) + addrs, err := hp.lookupHost(host) if err != nil { - return err + continue } + delete(hp.unresolvedServers, server) for _, addr := range addrs { found = append(found, net.JoinHostPort(addr, port)) } } - - if len(found) == 0 { - return fmt.Errorf("No hosts found for addresses %q", servers) - } - // Randomize the order of the servers to avoid creating hotspots stringShuffle(found) - hp.servers = found + hp.mu.Lock() + defer hp.mu.Unlock() + + hp.servers = append(hp.servers, found...) hp.curr = -1 hp.last = -1 - return nil + if len(hp.servers) == 0 { + return true, fmt.Errorf("No hosts found for addresses %q", hp.servers) + } + + return false, nil } // Len returns the number of servers available diff --git a/zk/dnshostprovider_test.go b/zk/dnshostprovider_test.go index 77a60658..70eabb47 100644 --- a/zk/dnshostprovider_test.go +++ b/zk/dnshostprovider_test.go @@ -1,6 +1,7 @@ package zk import ( + "errors" "fmt" "log" "testing" @@ -165,6 +166,50 @@ func TestDNSHostProviderReconnect(t *testing.T) { } } +// TestDNSHostOneHostDead tests whether +func TestDNSHostOneHostDead(t *testing.T) { + // use channel to simulate a server that was initially dead but came back online later + ch := make(chan struct{}, 0) + hp := &DNSHostProvider{lookupHost: func(host string) ([]string, error) { + if host != "foo.failure.com" { + return []string{"192.0.2.1", "192.0.2.2"}, nil + } + select { + case <-ch: + return []string{"192.0.2.3"}, nil + default: + return nil, errors.New("Fails to ns lookup") + } + }, sleep: func(_ time.Duration) {}} + + if err := hp.Init([]string{"foo.failure.com:12345", "foo.success.com:12345"}); err != nil { + t.Fatal(err) + } + + hp.mu.Lock() + if len(hp.servers) != 2 { + t.Fatal("Only servers that resolved by lookupHost should be in servers list") + } + hp.mu.Unlock() + + // simulating one server comes back online + close(ch) + + // starts a 30s retry loop to wait servers list to be updated + startRetryLoop := time.Now() + for { + time.Sleep(time.Millisecond * 5) + hp.mu.Lock() + if len(hp.servers) == 3 { + break + } + hp.mu.Unlock() + if time.Since(startRetryLoop) > time.Second*30 { + t.Fatal("Servers get back online should be added to the servers list") + } + } +} + // TestDNSHostProviderRetryStart tests the `retryStart` functionality // of DNSHostProvider. // It's also probably the clearest visual explanation of exactly how