Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

start dns refactor #1239

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
package dns
package coredns

import (
"reflect"
"unsafe"

"beryju.io/gravity/pkg/roles/dns/handlers"
"beryju.io/gravity/pkg/roles/dns/utils"
"github.com/coredns/caddy"
"github.com/coredns/coredns/core/dnsserver"
Expand All @@ -23,11 +24,11 @@ type CoreDNS struct {
srv *dnsserver.Server
}

func NewCoreDNS(z *Zone, rawConfig map[string]string) *CoreDNS {
func NewCoreDNS(z handlers.HandlerZoneContext, rawConfig map[string]string) *CoreDNS {
core := &CoreDNS{
c: rawConfig,
}
core.log = z.log.With(zap.String("handler", core.Identifier()))
core.log = z.Log().With(zap.String("handler", core.Identifier()))
dnsserver.Quiet = true
corefile := caddy.CaddyfileInput{
Contents: []byte(core.c["config"]),
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package dns_test
package coredns_test

import (
"net"
Expand All @@ -19,6 +19,12 @@ const CoreDNSConfig = `.:1342 {
}
}`

func RoleConfig() []byte {
return []byte(tests.MustJSON(dns.RoleConfig{
Port: 1054,
}))
}

func TestRoleDNSHandlerCoreDNS(t *testing.T) {
defer tests.Setup(t)()
rootInst := instance.New()
Expand Down Expand Up @@ -46,7 +52,7 @@ func TestRoleDNSHandlerCoreDNS(t *testing.T) {
assert.Nil(t, role.Start(ctx, RoleConfig()))
defer role.Stop()

fw := NewNullDNSWriter()
fw := dns.NewNullDNSWriter()
role.Handler(fw, &d.Msg{
Question: []d.Question{
{
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
package dns
package etcd

import (
"strings"

"beryju.io/gravity/pkg/roles/dns/handlers"
"beryju.io/gravity/pkg/roles/dns/types"
"beryju.io/gravity/pkg/roles/dns/utils"
"beryju.io/gravity/pkg/storage"
Expand All @@ -15,28 +16,28 @@ import (
const EtcdType = "etcd"

type EtcdHandler struct {
log *zap.Logger
z *Zone
lookupKey func(k *storage.Key, qname string, r *utils.DNSRequest) []dns.RR
log *zap.Logger
z handlers.HandlerZoneContext
LookupKeyFunc func(k *storage.Key, qname string, r *utils.DNSRequest) []dns.RR
}

func NewEtcdHandler(z *Zone, config map[string]string) *EtcdHandler {
func NewEtcdHandler(z handlers.HandlerZoneContext, config map[string]string) *EtcdHandler {
eh := &EtcdHandler{
z: z,
}
eh.lookupKey = func(k *storage.Key, qname string, r *utils.DNSRequest) []dns.RR {
eh.LookupKeyFunc = func(k *storage.Key, qname string, r *utils.DNSRequest) []dns.RR {
answers := []dns.RR{}
es := sentry.TransactionFromContext(r.Context()).StartChild("gravity.dns.handler.etcd.get")
defer es.Finish()
key := k.String()
eh.log.Debug("fetching kv key", zap.String("key", key))
es.SetTag("gravity.dns.handler.etcd.key", key)
res, err := eh.z.inst.KV().Get(r.Context(), key, clientv3.WithPrefix())
res, err := eh.z.RoleInstance().KV().Get(r.Context(), key, clientv3.WithPrefix())
if err != nil || len(res.Kvs) < 1 {
return answers
}
for _, kv := range res.Kvs {
rec, err := eh.z.recordFromKV(kv)
rec, err := eh.z.RecordFromKV(kv)
if err != nil {
continue
}
Expand All @@ -47,7 +48,7 @@ func NewEtcdHandler(z *Zone, config map[string]string) *EtcdHandler {
}
return answers
}
eh.log = z.log.With(zap.String("handler", eh.Identifier()))
eh.log = z.Log().With(zap.String("handler", eh.Identifier()))
return eh
}

Expand All @@ -64,8 +65,8 @@ func (eh *EtcdHandler) findWildcard(r *utils.DNSRequest, relRecordName string, q
// Replace the current dot part with a wildcard (make sure to only replace 1 occurrence,
// since we replace from left to right)
wildcardName = strings.Replace(wildcardName, part, types.DNSWildcard, 1)
wildcardKey := eh.z.inst.KV().Key(eh.z.etcdKey, strings.ToLower(wildcardName), dns.Type(question.Qtype).String())
wildcardAns := eh.lookupKey(wildcardKey, question.Name, r)
wildcardKey := eh.z.RoleInstance().KV().Key(eh.z.EtcdKey(), strings.ToLower(wildcardName), dns.Type(question.Qtype).String())
wildcardAns := eh.LookupKeyFunc(wildcardKey, question.Name, r)
// If we do get an answer from this wildcard key, stop going further
if len(wildcardAns) > 0 {
return wildcardAns
Expand All @@ -86,8 +87,8 @@ func (eh *EtcdHandler) handleSingleQuestion(question dns.Question, r *utils.DNSR
// in the database
relRecordName = strings.TrimSuffix(relRecordName, ".")
}
directRecordKey := eh.z.inst.KV().Key(
eh.z.etcdKey,
directRecordKey := eh.z.RoleInstance().KV().Key(
eh.z.EtcdKey(),
strings.ToLower(relRecordName),
)
if question.Qtype != dns.TypeNone {
Expand All @@ -101,7 +102,7 @@ func (eh *EtcdHandler) handleSingleQuestion(question dns.Question, r *utils.DNSR
directRecordKey = directRecordKey.Prefix(true)
}
// Look for direct matches first
answers = append(answers, eh.lookupKey(
answers = append(answers, eh.LookupKeyFunc(
directRecordKey,
question.Name,
r,
Expand Down Expand Up @@ -137,9 +138,9 @@ func (eh *EtcdHandler) Handle(w *utils.FakeDNSWriter, r *utils.DNSRequest) *dns.
for un := range uniqueQuestionNames {
// Look for CNAMEs
relRecordName := strings.TrimSuffix(strings.ToLower(un), strings.ToLower(utils.EnsureLeadingPeriod(eh.z.Name)))
cnames := eh.lookupKey(
eh.z.inst.KV().Key(
eh.z.etcdKey,
cnames := eh.LookupKeyFunc(
eh.z.RoleInstance().KV().Key(
eh.z.EtcdKey(),
strings.ToLower(relRecordName),
dns.Type(dns.TypeCNAME).String(),
),
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package dns_test
package etcd_test

import (
"net"
Expand Down Expand Up @@ -54,7 +54,7 @@ func TestRoleDNS_Etcd(t *testing.T) {
assert.Nil(t, role.Start(ctx, RoleConfig()))
defer role.Stop()

fw := NewNullDNSWriter()
fw := dns.NewNullDNSWriter()
role.Handler(fw, &d.Msg{
Question: []d.Question{
{
Expand Down Expand Up @@ -109,7 +109,7 @@ func TestRoleDNS_Etcd_Root(t *testing.T) {
assert.Nil(t, role.Start(ctx, RoleConfig()))
defer role.Stop()

fw := NewNullDNSWriter()
fw := dns.NewNullDNSWriter()
role.Handler(fw, &d.Msg{
Question: []d.Question{
{
Expand Down Expand Up @@ -163,7 +163,7 @@ func TestRoleDNS_Etcd_Wildcard(t *testing.T) {
assert.Nil(t, role.Start(ctx, RoleConfig()))
defer role.Stop()

fw := NewNullDNSWriter()
fw := dns.NewNullDNSWriter()
role.Handler(fw, &d.Msg{
Question: []d.Question{
{
Expand Down Expand Up @@ -231,7 +231,7 @@ func TestRoleDNS_Etcd_CNAME(t *testing.T) {
assert.Nil(t, role.Start(ctx, RoleConfig()))
defer role.Stop()

fw := NewNullDNSWriter()
fw := dns.NewNullDNSWriter()
role.Handler(fw, &d.Msg{
Question: []d.Question{
{
Expand All @@ -244,7 +244,7 @@ func TestRoleDNS_Etcd_CNAME(t *testing.T) {
ans := fw.Msg().Answer[0]
assert.Equal(t, net.ParseIP("10.2.3.4").String(), ans.(*d.A).A.String())

fw = NewNullDNSWriter()
fw = dns.NewNullDNSWriter()
role.Handler(fw, &d.Msg{
Question: []d.Question{
{
Expand Down Expand Up @@ -298,7 +298,7 @@ func TestRoleDNS_Etcd_WildcardNested(t *testing.T) {
assert.Nil(t, role.Start(ctx, RoleConfig()))
defer role.Stop()

fw := NewNullDNSWriter()
fw := dns.NewNullDNSWriter()
role.Handler(fw, &d.Msg{
Question: []d.Question{
{
Expand Down Expand Up @@ -352,7 +352,7 @@ func TestRoleDNS_Etcd_MixedCase(t *testing.T) {
assert.Nil(t, role.Start(ctx, RoleConfig()))
defer role.Stop()

fw := NewNullDNSWriter()
fw := dns.NewNullDNSWriter()
role.Handler(fw, &d.Msg{
Question: []d.Question{
{
Expand Down Expand Up @@ -406,7 +406,7 @@ func TestRoleDNS_Etcd_MixedCase_Reverse(t *testing.T) {
assert.Nil(t, role.Start(ctx, RoleConfig()))
defer role.Stop()

fw := NewNullDNSWriter()
fw := dns.NewNullDNSWriter()
role.Handler(fw, &d.Msg{
Question: []d.Question{
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ func TestRoleDNS_BlockyForwarder(t *testing.T) {
assert.Nil(t, role.Start(ctx, RoleConfig()))
defer role.Stop()

fw := NewNullDNSWriter()
fw := dns.NewNullDNSWriter()
role.Handler(fw, &d.Msg{
Question: []d.Question{
{
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
package dns
package forward_ip

import (
"math/rand"
"strconv"
"strings"

"beryju.io/gravity/pkg/roles/dns/handlers"
"beryju.io/gravity/pkg/roles/dns/types"
"beryju.io/gravity/pkg/roles/dns/utils"
"github.com/getsentry/sentry-go"
Expand All @@ -18,12 +19,12 @@ type IPForwarderHandler struct {
c *dns.Client
resolvers []string

z *Zone
z handlers.HandlerZoneContext
log *zap.Logger
CacheTTL int
}

func NewIPForwarderHandler(z *Zone, config map[string]string) *IPForwarderHandler {
func NewIPForwarderHandler(z handlers.HandlerZoneContext, config map[string]string) *IPForwarderHandler {
net, ok := config["net"]
if !ok {
net = ""
Expand All @@ -37,7 +38,7 @@ func NewIPForwarderHandler(z *Zone, config map[string]string) *IPForwarderHandle
},
resolvers: strings.Split(config["to"], ";"),
}
ipf.log = z.log.With(zap.String("handler", ipf.Identifier()))
ipf.log = z.Log().With(zap.String("handler", ipf.Identifier()))

rawTtl := config["cache_ttl"]
cacheTtl, err := strconv.Atoi(rawTtl)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package dns_test
package forward_ip_test

import (
"net"
Expand Down Expand Up @@ -41,7 +41,7 @@ func TestRoleDNS_IPForwarder_v4(t *testing.T) {
assert.Nil(t, role.Start(ctx, RoleConfig()))
defer role.Stop()

fw := NewNullDNSWriter()
fw := dns.NewNullDNSWriter()
role.Handler(fw, &d.Msg{
Question: []d.Question{
{
Expand Down Expand Up @@ -91,7 +91,7 @@ func TestRoleDNS_IPForwarder_v4_Cache(t *testing.T) {
assert.Nil(t, role.Start(ctx, RoleConfig()))
defer role.Stop()

fw := NewNullDNSWriter()
fw := dns.NewNullDNSWriter()
role.Handler(fw, &d.Msg{
Question: []d.Question{
{
Expand Down Expand Up @@ -150,7 +150,7 @@ func TestRoleDNS_IPForwarder_v6(t *testing.T) {
assert.Nil(t, role.Start(ctx, RoleConfig()))
defer role.Stop()

fw := NewNullDNSWriter()
fw := dns.NewNullDNSWriter()
role.Handler(fw, &d.Msg{
Question: []d.Question{
{
Expand Down
27 changes: 27 additions & 0 deletions pkg/roles/dns/handlers/handler.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
package handlers

import (
"beryju.io/gravity/pkg/roles"
"beryju.io/gravity/pkg/roles/dns/utils"
"github.com/miekg/dns"
"go.etcd.io/etcd/api/v3/mvccpb"
"go.uber.org/zap"
)

type Handler interface {
Handle(w *utils.FakeDNSWriter, r *utils.DNSRequest) *dns.Msg
Identifier() string
}

type HandlerZoneContext interface {
Log() *zap.Logger
// // TODO Rename to Zone()
// GetZone() *types.Zone
RecordFromKV(kv *mvccpb.KeyValue) (HandlerRecord, error)
RoleInstance() roles.Instance
EtcdKey() string
}

type HandlerRecord interface {
ToDNS(string) dns.RR
}
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
package dns
package memory

import (
"strings"

"beryju.io/gravity/pkg/roles/dns/handlers"
"beryju.io/gravity/pkg/roles/dns/handlers/etcd"
"beryju.io/gravity/pkg/roles/dns/utils"
"beryju.io/gravity/pkg/storage"
"github.com/miekg/dns"
Expand All @@ -12,17 +14,17 @@ import (
const MemoryType = "memory"

type MemoryHandler struct {
*EtcdHandler
*etcd.EtcdHandler
log *zap.Logger
z *Zone
z handlers.HandlerZoneContext
}

func NewMemoryHandler(z *Zone, config map[string]string) *MemoryHandler {
func NewMemoryHandler(z handlers.HandlerZoneContext, config map[string]string) *MemoryHandler {
mh := &MemoryHandler{
EtcdHandler: &EtcdHandler{z: z},
EtcdHandler: etcd.NewEtcdHandler(z, config),
z: z,
}
mh.lookupKey = func(k *storage.Key, qname string, r *utils.DNSRequest) []dns.RR {
mh.LookupKeyFunc = func(k *storage.Key, qname string, r *utils.DNSRequest) []dns.RR {
answers := []dns.RR{}
mh.z.recordsSync.RLock()
defer mh.z.recordsSync.RUnlock()
Expand Down
Loading