diff --git a/zk/dnshostprovider.go b/zk/dnshostprovider.go index f4bba8d0..4060294d 100644 --- a/zk/dnshostprovider.go +++ b/zk/dnshostprovider.go @@ -4,24 +4,68 @@ import ( "fmt" "net" "sync" + "time" ) +// lookupInterval is the interval of retrying DNS lookup for unresolved hosts +const lookupInterval = time.Minute * 3 + // DNSHostProvider is the default HostProvider. It currently matches // the Java StaticHostProvider, resolving hosts from DNS once during // the call to Init. It could be easily extended to re-query DNS // periodically or if there is trouble connecting. type DNSHostProvider struct { - mu sync.Mutex // Protects everything, so we can add asynchronous updates later. - servers []string - curr int - last int - lookupHost func(string) ([]string, error) // Override of net.LookupHost, for testing. + sleep func(time.Duration) // Override of time.Sleep, for testing. + + mu sync.Mutex // Protects everything, so we can add asynchronous updates later. + servers []string + unresolvedServers map[string]struct{} + curr int + last int + lookupHost func(string) ([]string, error) // Override of net.LookupHost, for testing. } // Init is called first, with the servers specified in the connection // string. It uses DNS to look up addresses for each server, then // shuffles them all together. func (hp *DNSHostProvider) Init(servers []string) error { + if hp.sleep == nil { + hp.sleep = time.Sleep + } + hp.servers = make([]string, 0, len(servers)) + hp.unresolvedServers = make(map[string]struct{}, len(servers)) + for _, server := range servers { + hp.unresolvedServers[server] = struct{}{} + } + + done, err := hp.lookupUnresolvedServers() + if err != nil { + return err + } + + // as long as any host resolved successfully, consider the connection as success + // but start a lookup loop until all servers are resolved and added to servers list + if !done { + go hp.lookupLoop() + } + + return nil +} + +// lookupLoop calls lookupUnresolvedServers in an infinite loop until all hosts are resolved +// should be called in a separate goroutine +func (hp *DNSHostProvider) lookupLoop() { + for { + if done, _ := hp.lookupUnresolvedServers(); done { + break + } + hp.sleep(lookupInterval) + } +} + +// lookupUnresolvedServers DNS lookup the hosts that not successfully resolved yet +// and add them to servers list +func (hp *DNSHostProvider) lookupUnresolvedServers() (bool, error) { hp.mu.Lock() defer hp.mu.Unlock() @@ -30,33 +74,37 @@ func (hp *DNSHostProvider) Init(servers []string) error { lookupHost = net.LookupHost } - found := []string{} - for _, server := range servers { + if len(hp.unresolvedServers) == 0 { + return true, nil + } + + found := make([]string, 0, len(hp.unresolvedServers)) + for server := range hp.unresolvedServers { host, port, err := net.SplitHostPort(server) if err != nil { - return err + return false, err } addrs, err := lookupHost(host) if err != nil { - return err + continue } + delete(hp.unresolvedServers, server) for _, addr := range addrs { found = append(found, net.JoinHostPort(addr, port)) } } - - if len(found) == 0 { - return fmt.Errorf("No hosts found for addresses %q", servers) - } - // Randomize the order of the servers to avoid creating hotspots stringShuffle(found) - hp.servers = found + hp.servers = append(hp.servers, found...) hp.curr = -1 hp.last = -1 - return nil + if len(hp.servers) == 0 { + return true, fmt.Errorf("No hosts found for addresses %q", hp.servers) + } + + return false, nil } // Len returns the number of servers available diff --git a/zk/dnshostprovider_test.go b/zk/dnshostprovider_test.go index 77a60658..2dd5b530 100644 --- a/zk/dnshostprovider_test.go +++ b/zk/dnshostprovider_test.go @@ -1,6 +1,7 @@ package zk import ( + "errors" "fmt" "log" "testing" @@ -165,6 +166,42 @@ func TestDNSHostProviderReconnect(t *testing.T) { } } +// TestDNSHostOneHostDead tests whether +func TestDNSHostOneHostDead(t *testing.T) { + hp := &DNSHostProvider{lookupHost: func(host string) ([]string, error) { + if host == "foo.failure.com" { + return nil, errors.New("Fails to ns lookup") + } + return []string{"192.0.2.1", "192.0.2.2"}, nil + }, sleep: func(_ time.Duration) {}} + + if err := hp.Init([]string{"foo.failure.com:12345", "foo.success.com:12345"}); err != nil { + t.Fatal(err) + } + + hp.mu.Lock() + if len(hp.servers) != 2 { + t.Fatal("Only servers that resolved by lookupHost should be in servers list") + } + + // update lookupHost to mock a successful lookup + hp.lookupHost = func(host string) ([]string, error) { + if host == "foo.failure.com" { + return []string{"192.0.2.3"}, nil + } + return []string{"192.0.2.1", "192.0.2.2"}, nil + } + hp.mu.Unlock() + + time.Sleep(time.Millisecond * 5) + + hp.mu.Lock() + if len(hp.servers) != 3 { + t.Fatal("Servers get back online should be added to the servers list") + } + hp.mu.Unlock() +} + // TestDNSHostProviderRetryStart tests the `retryStart` functionality // of DNSHostProvider. // It's also probably the clearest visual explanation of exactly how