Skip to content

Commit

Permalink
Merge pull request #67 from telekom-mms/feature/rework-answer-handlin…
Browse files Browse the repository at this point in the history
…g-in-dnsproxy

Feature/rework answer handling in dnsproxy
  • Loading branch information
hwipl authored Apr 10, 2024
2 parents 2230f0d + f61ab61 commit 99d2ecd
Show file tree
Hide file tree
Showing 4 changed files with 336 additions and 125 deletions.
127 changes: 72 additions & 55 deletions internal/dnsproxy/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,65 +47,82 @@ func (p *Proxy) handleRequest(w dns.ResponseWriter, r *dns.Msg) {
}

// parse answers in reply from remote server
for _, a := range reply.Answer {
name := a.Header().Name
if !p.watches.Contains(r.Question[0].Name) &&
!p.watches.Contains(name) {
// not on watch list, ignore answer
continue
// handler for DNAME answers
handleDNAME := func(a dns.RR) {
// DNAME record, store temporary watch
rr, ok := a.(*dns.DNAME)
if !ok {
log.Error("DNS-Proxy received invalid DNAME record in reply")
return
}

// get TTL
ttl := a.Header().Ttl

switch a.Header().Rrtype {
case dns.TypeA:
// A Record, get IPv4 address
rr, ok := a.(*dns.A)
if !ok {
log.Error("DNS-Proxy received invalid A record in reply")
continue
}
report := NewReport(name, rr.A, ttl)
p.reports <- report
report.Wait()

case dns.TypeAAAA:
// AAAA Record, get IPv6 address
rr, ok := a.(*dns.AAAA)
if !ok {
log.Error("DNS-Proxy received invalid AAAA record in reply")
continue
}
report := NewReport(name, rr.AAAA, ttl)
p.reports <- report
report.Wait()

case dns.TypeCNAME:
// CNAME record, store temporary watch
rr, ok := a.(*dns.CNAME)
if !ok {
log.Error("DNS-Proxy received invalid CNAME record in reply")
log.WithFields(log.Fields{
"target": rr.Target,
"ttl": rr.Hdr.Ttl,
}).Debug("DNS-Proxy received DNAME in reply")
p.watches.AddTempDNAME(rr.Target, rr.Hdr.Ttl)
}

// handler for CNAME answers
handleCNAME := func(a dns.RR) {
// CNAME record, store temporary watch
rr, ok := a.(*dns.CNAME)
if !ok {
log.Error("DNS-Proxy received invalid CNAME record in reply")
return
}
log.WithFields(log.Fields{
"target": rr.Target,
"ttl": rr.Hdr.Ttl,
}).Debug("DNS-Proxy received CNAME in reply")
p.watches.AddTempCNAME(rr.Target, rr.Hdr.Ttl)
}

// handler for A answers
handleA := func(a dns.RR) {
// A Record, get IPv4 address
rr, ok := a.(*dns.A)
if !ok {
log.Error("DNS-Proxy received invalid A record in reply")
return
}
report := NewReport(rr.Hdr.Name, rr.A, rr.Hdr.Ttl)
p.reports <- report
report.Wait()
}

// handler for AAAA answers
handleAAAA := func(a dns.RR) {
// AAAA Record, get IPv6 address
rr, ok := a.(*dns.AAAA)
if !ok {
log.Error("DNS-Proxy received invalid AAAA record in reply")
return
}
report := NewReport(rr.Hdr.Name, rr.AAAA, rr.Hdr.Ttl)
p.reports <- report
report.Wait()
}

// handle DNAME and CNAME records before A and AAAA records to make
// sure temporary watches are set before checking address records
for _, m := range []map[uint16]func(dns.RR){
{dns.TypeDNAME: handleDNAME},
{dns.TypeCNAME: handleCNAME},
{dns.TypeA: handleA, dns.TypeAAAA: handleAAAA},
} {
for _, a := range reply.Answer {
// ignore domain names we do not watch
name := a.Header().Name
if !p.watches.Contains(name) {
// not on watch list, ignore answer
continue
}
log.WithFields(log.Fields{
"target": rr.Target,
"ttl": ttl,
}).Debug("DNS-Proxy received CNAME in reply")
p.watches.AddTemp(rr.Target, ttl)

case dns.TypeDNAME:
// DNAME record, store temporary watch
rr, ok := a.(*dns.DNAME)
if !ok {
log.Error("DNS-Proxy received invalid DNAME record in reply")
continue

// handle record types
typ := a.Header().Rrtype
if m[typ] != nil {
m[typ](a)
}
log.WithFields(log.Fields{
"target": rr.Target,
"ttl": ttl,
}).Debug("DNS-Proxy received DNAME in reply")
p.watches.AddTemp(rr.Target, ttl)
}
}

Expand Down
81 changes: 81 additions & 0 deletions internal/dnsproxy/proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,87 @@ func TestProxyHandleRequest(t *testing.T) {
p.handleRequest(&responseWriter{}, &dns.Msg{Question: []dns.Question{{Name: "test.example.com."}}})
}

// TestProxyHandleRequest tests handleRequest of Proxy, DNS records.
// This tests different CNAME, DNAME, A, AAAA combinations.
func TestProxyHandleRequestRecords(t *testing.T) {
// dns records
dname, _ := dns.NewRR("test.example.com 3600 IN DNAME example.com.")
cname, _ := dns.NewRR("test.example.com 3600 IN CNAME example.com.")
a, _ := dns.NewRR("example.com. 3600 IN A 127.0.0.1")
aaaa, _ := dns.NewRR("example.com. 3600 IN AAAA ::1")

// answers to test with CNAME, DNAME, A, AAAA combinations
answers := [][]dns.RR{
{cname, a, aaaa},
{aaaa, a, cname},
{dname, a, aaaa},
{aaaa, a, dname},
{dname, cname, aaaa, a},
{cname, aaaa, a, dname},
{aaaa, a, dname, cname},
{aaaa, a, cname, dname},
}

// start test server that returns answers
answersChan := make(chan []dns.RR, len(answers))
for _, a := range answers {
answersChan <- a
}
s := getTestDNSServer(t, dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
reply := &dns.Msg{}
reply.SetReply(r)

reply.Answer = <-answersChan
if err := w.WriteMsg(reply); err != nil {
log.WithError(err).Error("error sending reply")
}
}))
defer func() { _ = s.Shutdown() }()

// test helper function
handle := func() []*Report {
// start proxy with remotes and watches
p := NewProxy(getTestConfig())
p.SetRemotes(map[string][]string{".": {s.Addr}})
p.SetWatches([]string{"test.example.com."})

// collect reports
reports := []*Report{}
reportsDone := make(chan struct{})
go func() {
defer close(reportsDone)
for r := range p.Reports() {
reports = append(reports, r)
r.Done()
}
}()

// handle request and return reports
p.handleRequest(&responseWriter{}, &dns.Msg{Question: []dns.Question{
{Name: "test.example.com."}}})
close(p.reports)
<-reportsDone

return reports
}

// test CNAME, DNAME, A, AAAA combinations in answers
for i := range answers {
reports := handle()
if len(reports) != 2 {
t.Fatalf("invalid reports for run %d: %v", i, reports)
}
for _, r := range reports {
if !r.IP.Equal(net.ParseIP("127.0.0.1")) &&
!r.IP.Equal(net.ParseIP("::1")) {

t.Errorf("invalid report for run %d: %v", i, r)
}
}

}
}

// TestProxyStartStop tests Start and Stop of Proxy.
func TestProxyStartStop(_ *testing.T) {
// tcp and udp listeners
Expand Down
74 changes: 49 additions & 25 deletions internal/dnsproxy/watches.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,10 @@ type tempWatch struct {
type Watches struct {
sync.RWMutex
m map[string]bool
t map[string]*tempWatch
// temporary CNAMEs
c map[string]*tempWatch
// temporary DNAMEs
d map[string]*tempWatch

done chan struct{}
closed chan struct{}
Expand All @@ -37,12 +40,23 @@ func (w *Watches) Add(domain string) {
w.m[domain] = true
}

// AddTemp adds a temporary domain to the watch list with a ttl.
func (w *Watches) AddTemp(domain string, ttl uint32) {
// AddTempCNAME adds a temporary CNAME domain to the watch list with a ttl.
func (w *Watches) AddTempCNAME(domain string, ttl uint32) {
w.Lock()
defer w.Unlock()

w.t[domain] = &tempWatch{
w.c[domain] = &tempWatch{
ttl: ttl,
updated: true,
}
}

// AddTempDNAME adds a temporary DNAME domain to the watch list with a ttl.
func (w *Watches) AddTempDNAME(domain string, ttl uint32) {
w.Lock()
defer w.Unlock()

w.d[domain] = &tempWatch{
ttl: ttl,
updated: true,
}
Expand All @@ -54,7 +68,8 @@ func (w *Watches) Remove(domain string) {
defer w.Unlock()

delete(w.m, domain)
delete(w.t, domain)
delete(w.c, domain)
delete(w.d, domain)
}

// cleanTemp removes expired temporary entries from the watch list and reduces
Expand All @@ -64,21 +79,23 @@ func (w *Watches) cleanTemp(interval uint32) {
w.Lock()
defer w.Unlock()

for d, t := range w.t {
if t.updated {
// mark new entries as old
t.updated = false
continue
}
for _, temps := range []map[string]*tempWatch{w.c, w.d} {
for d, t := range temps {
if t.updated {
// mark new entries as old
t.updated = false
continue
}

if t.ttl < interval {
// delete expired entry
delete(w.t, d)
continue
}
if t.ttl < interval {
// delete expired entry
delete(temps, d)
continue
}

// reduce ttl
t.ttl -= interval
// reduce ttl
t.ttl -= interval
}
}
}

Expand Down Expand Up @@ -112,7 +129,8 @@ func (w *Watches) Flush() {
defer w.Unlock()

w.m = make(map[string]bool)
w.t = make(map[string]*tempWatch)
w.c = make(map[string]*tempWatch)
w.d = make(map[string]*tempWatch)
}

// Contains returns whether the domain is in the watch list.
Expand All @@ -128,15 +146,20 @@ func (w *Watches) Contains(domain string) bool {
// get label indexes and find matching domains
labels := dns.Split(domain)
if labels == nil {
// root domain
// TODO: remove temp domain check here?
return w.m["."] || w.t["."] != nil
// root domain, not supported in watch list
return false
}

// try finding exact domain name in temporary CNAMEs
if w.c[domain] != nil {
return true
}

// try finding longest matching domain name
// try finding longest matching domain name in watched domains and
// temporary DNAMEs
for _, i := range labels {
d := domain[i:]
if w.m[d] || w.t[d] != nil {
if w.m[d] || w.d[d] != nil {
return true
}
}
Expand All @@ -155,7 +178,8 @@ func (w *Watches) Close() {
func NewWatches() *Watches {
w := &Watches{
m: make(map[string]bool),
t: make(map[string]*tempWatch),
c: make(map[string]*tempWatch),
d: make(map[string]*tempWatch),

done: make(chan struct{}),
closed: make(chan struct{}),
Expand Down
Loading

0 comments on commit 99d2ecd

Please sign in to comment.