Skip to content

Commit

Permalink
Merge pull request #239 from DasSkelett/fix/dns-proxy-big-packets
Browse files Browse the repository at this point in the history
  • Loading branch information
DasSkelett authored Aug 13, 2022
2 parents 6d11b95 + dae0292 commit c8da2fc
Show file tree
Hide file tree
Showing 6 changed files with 163 additions and 13 deletions.
File renamed without changes.
26 changes: 26 additions & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
name: Unit tests

on:
push:
branches:
- master
pull_request:

jobs:
test-go:
runs-on: ubuntu-latest
strategy:
matrix:
go: [ '1.18', '1.19' ]
name: Go ${{ matrix.go }} tests
steps:
- name: Checkout
uses: actions/checkout@v3
- name: Setup Go
uses: actions/setup-go@v3
with:
go-version: ${{ matrix.go }}
cache: true
- name: Run go test
run: go test -v ./...

3 changes: 2 additions & 1 deletion cmd/serve/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -184,9 +184,10 @@ func (cmd *servecmd) Run() {
ListenAddr: listenAddr,
})
if err != nil {
logrus.Error(errors.Wrap(err, "failed to start dns server"))
logrus.Error(errors.Wrap(err, "failed to create dns server"))
return
}
dns.ListenAndServe()
defer dns.Close()
if conf.DNS.Domain != "" {
// Generate initial DNS zone for registered devices
Expand Down
50 changes: 43 additions & 7 deletions internal/dnsproxy/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,13 @@ import (
)

type DNSProxy struct {
client *dns.Client
cache *cache.Cache
upstream []string
udpClient *dns.Client
tcpClient *dns.Client
cache *cache.Cache
upstream []string
}

// ServeDNS is called by the mux from the listening servers.
func (d *DNSProxy) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
defer func() {
if err := recover(); err != nil {
Expand All @@ -32,13 +34,17 @@ func (d *DNSProxy) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
case dns.OpcodeQuery:
// Remove EDNS0 Client Subnet information as we don't handle them in the cache
purgeECS(r)
m, err := d.Lookup(r)
outQuery := r.Copy()
// Set EDNS BufSize for forwarding to upstream
ensureEDNS0BufSize(outQuery)
m, err := d.Lookup(outQuery)
if err != nil {
logrus.Errorf("failed lookup record with error: %s\n%s", err.Error(), r)
HandleFailed(w, r)
return
}
m.SetReply(r)
truncateIfRequired(m, r, w.RemoteAddr().Network())
err = w.WriteMsg(m)
if err != nil {
logrus.Errorf("failed write response for client with error: %s\n%s", err.Error(), r)
Expand All @@ -56,25 +62,36 @@ func (d *DNSProxy) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {

}

// Lookup first checks the cache for a matching response, and if unsuccessful queries the upstream resolvers.
func (d *DNSProxy) Lookup(m *dns.Msg) (*dns.Msg, error) {
key := makekey(m)

// check the cache first
if item, found := d.cache.Get(key); found {
logrus.Debugf("dns cache hit %s", prettyPrintMsg(m))
return item.(*dns.Msg), nil
return item.(*dns.Msg).Copy(), nil
}

// fallback to upstream exchange
// TODO disable upstream after certain amount of failures?
var response *dns.Msg
var firstErr error
for _, upstream := range d.upstream {
resp, _, err := d.client.Exchange(m, net.JoinHostPort(upstream, "53"))
target := net.JoinHostPort(upstream, "53")
resp, _, err := d.udpClient.Exchange(m, target)
if err != nil && firstErr == nil {
logrus.Warnf(errors.Wrap(err, fmt.Sprintf("DNS lookup failed for upstream %s", upstream)).Error())
firstErr = err
} else if err == nil {
// Retry truncated responses over TCP
if resp.Truncated {
resp, _, err = d.tcpClient.Exchange(m, target)
if err != nil && firstErr == nil {
logrus.Warnf(errors.Wrap(err, fmt.Sprintf("DNS lookup failed over TCP for upstream %s", upstream)).Error())
firstErr = err
continue
}
}
response = resp
break
}
Expand All @@ -89,7 +106,7 @@ func (d *DNSProxy) Lookup(m *dns.Msg) (*dns.Msg, error) {
d.cache.Set(key, response, ttl)
}

return response, nil
return response.Copy(), nil
}

func purgeECS(m *dns.Msg) {
Expand All @@ -101,3 +118,22 @@ func purgeECS(m *dns.Msg) {
}
}
}

func ensureEDNS0BufSize(m *dns.Msg) {
if opt := m.IsEdns0(); opt != nil {
opt.SetUDPSize(1232)
} else {
m.SetEdns0(1232, false)
}
}

func truncateIfRequired(response *dns.Msg, original *dns.Msg, transport string) {
size := dns.MinMsgSize
if transport == "tcp" {
size = dns.MaxMsgSize
} else if opt := original.IsEdns0(); opt != nil {
size = int(opt.UDPSize())
}
logrus.Debugf("truncating to %d", size)
response.Truncate(size)
}
58 changes: 58 additions & 0 deletions internal/dnsproxy/proxy_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
package dnsproxy

import (
"context"
"net"
"testing"
"time"
)

var ffmucUpstreams, _ = net.LookupHost("dns.ffmuc.net")

func TestDNSProxy_ServeDNS(t *testing.T) {
const listen = "[::1]:8053"

resolver := net.Resolver{
PreferGo: true,
Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
d := net.Dialer{Timeout: time.Second}
return d.DialContext(ctx, network, listen)
},
}

server, err := New(DNSServerOpts{
Domain: "",
ListenAddr: []string{listen},
Upstream: ffmucUpstreams,
})
server.ListenAndServe()
defer func() { _ = server.Close() }()

if err != nil {
t.Fatal(err)
}

t.Run("Reply over 1300 bytes", func(t *testing.T) {
_, err := resolver.LookupTXT(context.Background(), "cloudflare.com.")
if err != nil {
t.Error(err)
return
}
})
t.Run("Reply over 1500 bytes", func(t *testing.T) {
records, err := resolver.LookupTXT(context.Background(), "txtfill1500.go.dnscheck.tools.")
if err != nil {
t.Error(err)
return
}
var containsBigRecord bool
for _, r := range records {
if len(r) >= 1500 {
containsBigRecord = true
}
}
if !containsBigRecord {
t.Error("missing big TXT record, packet probably truncated")
}
})
}
39 changes: 34 additions & 5 deletions internal/dnsproxy/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,17 +24,22 @@ type DNSServer struct {
auth *DNSAuth
}

// New returns a pointer to a DNSServer configured using opts DNSServerOpts.
// The returned server needs to be started using DNSServer.ListenAndServe()
func New(opts DNSServerOpts) (*DNSServer, error) {
if len(opts.Upstream) == 0 {
return nil, errors.New("at least 1 upstream dns server is required for the dns proxy server to function")
}

logrus.Infof("starting dns server on %s with upstreams: %s", strings.Join(opts.ListenAddr, ", "), strings.Join(opts.Upstream, ", "))

dnsServer := &DNSServer{
servers: []*dns.Server{},
proxy: &DNSProxy{
client: &dns.Client{
udpClient: &dns.Client{
SingleInflight: true,
Timeout: 5 * time.Second,
},
tcpClient: &dns.Client{
Net: "tcp",
SingleInflight: true,
Timeout: 5 * time.Second,
},
Expand Down Expand Up @@ -72,15 +77,39 @@ func New(opts DNSServerOpts) (*DNSServer, error) {
dnsServer.servers = append(dnsServer.servers, tcpServer)
}

for _, server := range dnsServer.servers {
return dnsServer, nil
}

// ListenAndServe starts the DNSServer and waits until all listeners are up.
func (d *DNSServer) ListenAndServe() {
var sb strings.Builder
for i, s := range d.servers {
sb.WriteString(s.Addr)
sb.WriteString("/")
sb.WriteString(s.Net)
if i < len(d.servers)-1 {
sb.WriteString(", ")
}
}

logrus.Infof("starting dns server on %s with upstreams: %s", sb.String(), strings.Join(d.proxy.upstream, ", "))

var wg sync.WaitGroup

for _, server := range d.servers {
wg.Add(1)
server.NotifyStartedFunc = func() {
wg.Done()
}
go func(server *dns.Server) {
if err := server.ListenAndServe(); err != nil {
logrus.Error(errors.Errorf("failed to start DNS server on %s/%s: %s", server.Addr, server.Net, err))
wg.Done()
}
}(server)
}

return dnsServer, nil
wg.Wait()
}

func (d *DNSServer) Close() error {
Expand Down

0 comments on commit c8da2fc

Please sign in to comment.