Skip to content

Commit

Permalink
fix(settings): allow listening on a random port and using the empty l…
Browse files Browse the repository at this point in the history
…istening address
  • Loading branch information
qdm12 committed Nov 15, 2023
1 parent ad2b4d2 commit cbbf011
Show file tree
Hide file tree
Showing 15 changed files with 165 additions and 41 deletions.
14 changes: 9 additions & 5 deletions internal/config/settings/settings.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,12 @@ type Settings struct {
// Upstream is the DNS upstream connection type
// and can be either 'dot' or 'doh'.
// It defaults to 'dot' if left uset.
Upstream string
ListeningAddress string
Upstream string
// ListeningAddress is the DNS server listening address.
// It can be set to the empty string to listen on all interfaces
// on a random available port.
// It defaults to ":53".
ListeningAddress *string
Block Block
Cache Cache
DoH DoH
Expand All @@ -32,7 +36,7 @@ type Settings struct {

func (s *Settings) SetDefaults() {
s.Upstream = gosettings.DefaultString(s.Upstream, "dot")
s.ListeningAddress = gosettings.DefaultString(s.ListeningAddress, ":53")
s.ListeningAddress = gosettings.DefaultPointer(s.ListeningAddress, ":53")
s.Block.setDefaults()
s.Cache.setDefaults()
s.DoH.setDefaults()
Expand All @@ -56,7 +60,7 @@ func (s *Settings) Validate() (err error) {
}

const privilegedAllowedPort = 53
err = validate.ListeningAddress(s.ListeningAddress, os.Getuid(), privilegedAllowedPort)
err = validate.ListeningAddress(*s.ListeningAddress, os.Getuid(), privilegedAllowedPort)
if err != nil {
return fmt.Errorf("listening address: %w", err)
}
Expand Down Expand Up @@ -94,7 +98,7 @@ func (s *Settings) ToLinesNode() (node *gotree.Node) {
node = gotree.New("Settings:")

node.Appendf("DNS upstream connection: %s", s.Upstream)
node.Appendf("DNS server listening address: %s", s.ListeningAddress)
node.Appendf("DNS server listening address: %s", *s.ListeningAddress)

switch s.Upstream {
case "dot":
Expand Down
2 changes: 1 addition & 1 deletion internal/config/sources/env/reader.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ func (r *Reader) Read() (settings settings.Settings, err error) {
}

settings.Upstream = r.env.String("UPSTREAM_TYPE")
settings.ListeningAddress = r.env.String("LISTENING_ADDRESS")
settings.ListeningAddress = r.env.Get("LISTENING_ADDRESS")

settings.Block, err = r.readBlock()
if err != nil {
Expand Down
28 changes: 28 additions & 0 deletions internal/mockhelp/regex.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package mockhelp

import (
"regexp"
)

func NewMatcherRegex(regex string) *MatcherRegex {
return &MatcherRegex{
regex: regexp.MustCompile(regex),
}
}

type MatcherRegex struct {
regex *regexp.Regexp
}

func (m *MatcherRegex) String() string {
return "must match regex " + m.regex.String()
}

func (m *MatcherRegex) Matches(x interface{}) bool {
s, ok := x.(string)
if !ok {
return false
}

return m.regex.MatchString(s)
}
3 changes: 2 additions & 1 deletion internal/setup/doh.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
noopmetrics "github.com/qdm12/dns/v2/pkg/doh/metrics/noop"
prometheusmetrics "github.com/qdm12/dns/v2/pkg/doh/metrics/prometheus"
"github.com/qdm12/dns/v2/pkg/metrics/prometheus"
"github.com/qdm12/gosettings"
)

func dohServer(userSettings settings.Settings,
Expand All @@ -27,7 +28,7 @@ func dohServer(userSettings settings.Settings,

settings := doh.ServerSettings{
Resolver: resolverSettings,
ListeningAddress: userSettings.ListeningAddress,
ListeningAddress: gosettings.CopyPointer(userSettings.ListeningAddress),
Middlewares: toDoHMiddlewares(middlewares),
Logger: logger,
}
Expand Down
3 changes: 2 additions & 1 deletion internal/setup/dot.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
noopmetrics "github.com/qdm12/dns/v2/pkg/dot/metrics/noop"
prometheusmetrics "github.com/qdm12/dns/v2/pkg/dot/metrics/prometheus"
"github.com/qdm12/dns/v2/pkg/metrics/prometheus"
"github.com/qdm12/gosettings"
)

func dotServer(userSettings settings.Settings,
Expand All @@ -23,7 +24,7 @@ func dotServer(userSettings settings.Settings,

settings := dot.ServerSettings{
Resolver: resolverSettings,
ListeningAddress: userSettings.ListeningAddress,
ListeningAddress: gosettings.CopyPointer(userSettings.ListeningAddress),
Middlewares: toDoTMiddlewares(middlewares),
Logger: logger,
}
Expand Down
5 changes: 5 additions & 0 deletions pkg/doh/helpers_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
package doh

func ptrTo[T any](value T) *T {
return &value
}
17 changes: 13 additions & 4 deletions pkg/doh/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,18 +36,23 @@ func Test_Resolver(t *testing.T) {
}

func Test_Server(t *testing.T) {
server, err := NewServer(ServerSettings{})
server, err := NewServer(ServerSettings{
ListeningAddress: ptrTo(""),
})
require.NoError(t, err)

runError, startErr := server.Start()
require.NoError(t, startErr)

listeningAddress, err := server.ListeningAddress()
require.NoError(t, err)

resolver := &net.Resolver{
PreferGo: true,
StrictErrors: true,
Dial: func(ctx context.Context, network string, address string) (net.Conn, error) {
dialer := &net.Dialer{Timeout: time.Second}
return dialer.DialContext(ctx, "udp", "127.0.0.1:53")
return dialer.DialContext(ctx, "udp", listeningAddress.String())
},
}

Expand Down Expand Up @@ -232,7 +237,7 @@ func Test_Server_Mocks(t *testing.T) {
filterMiddleware := filtermiddleware.New(filter)

logger := NewMockLogger(ctrl)
logger.EXPECT().Info("DNS server listening on :53")
logger.EXPECT().Info(mockhelp.NewMatcherRegex("DNS server listening on .*:[1-9][0-9]{0,4}"))

metrics := NewMockMetrics(ctrl)
metrics.EXPECT().
Expand Down Expand Up @@ -265,18 +270,22 @@ func Test_Server_Mocks(t *testing.T) {
Resolver: ResolverSettings{
Metrics: metrics,
},
ListeningAddress: ptrTo(""),
})
require.NoError(t, err)

runError, startErr := server.Start()
require.NoError(t, startErr)

listeningAddress, err := server.ListeningAddress()
require.NoError(t, err)

resolver := &net.Resolver{
PreferGo: true,
StrictErrors: true,
Dial: func(ctx context.Context, network string, address string) (net.Conn, error) {
dialer := &net.Dialer{Timeout: time.Second}
return dialer.DialContext(ctx, "udp", "127.0.0.1:53")
return dialer.DialContext(ctx, "udp", listeningAddress.String())
},
}

Expand Down
43 changes: 37 additions & 6 deletions pkg/doh/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@ package doh

import (
"context"
"errors"
"fmt"
"net"
"sync"

"github.com/miekg/dns"
Expand Down Expand Up @@ -53,12 +55,16 @@ func (s *Server) Start() (runError <-chan error, startErr error) {
s.runningMutex.Unlock()

handlerCtx, handlerCancel := context.WithCancel(context.Background())
defer func() {
if startErr != nil {
handlerCancel()
}
}()

var handler dns.Handler
var err error
handler, err = newDNSHandler(handlerCtx, s.settings)
if err != nil {
handlerCancel()
return nil, fmt.Errorf("creating DNS handler: %w", err)
}

Expand All @@ -68,10 +74,20 @@ func (s *Server) Start() (runError <-chan error, startErr error) {

s.stop = make(chan struct{})
s.done = new(sync.WaitGroup)

listeningAddress, err := net.ResolveUDPAddr("udp", *s.settings.ListeningAddress)
if err != nil {
return nil, fmt.Errorf("resolving listening address: %w", err)
}

udpListener, err := net.ListenUDP("udp", listeningAddress)
if err != nil {
return nil, fmt.Errorf("creating UDP listener: %w", err)
}

s.dnsServer = dns.Server{
Addr: s.settings.ListeningAddress,
Net: "udp",
Handler: handler,
PacketConn: udpListener,
Handler: handler,
}

var ready sync.WaitGroup
Expand All @@ -89,9 +105,9 @@ func (s *Server) Start() (runError <-chan error, startErr error) {
s.done.Add(1)
go func() {
defer s.done.Done()
s.settings.Logger.Info("DNS server listening on " + s.dnsServer.Addr)
s.settings.Logger.Info("DNS server listening on " + s.dnsServer.PacketConn.LocalAddr().String())
ready.Done()
err := s.dnsServer.ListenAndServe()
err := s.dnsServer.ActivateAndServe()
s.runningMutex.Lock()
s.running = false
s.runningMutex.Unlock()
Expand Down Expand Up @@ -131,3 +147,18 @@ func (s *Server) Stop() (err error) {

return err
}

var (
ErrServerNotRunning = errors.New("server not running")
)

func (s *Server) ListeningAddress() (address net.Addr, err error) {
s.startStopMutex.Lock()
defer s.startStopMutex.Unlock()

if !s.running {
return nil, fmt.Errorf("%w", ErrServerNotRunning)
}

return s.dnsServer.PacketConn.LocalAddr(), nil
}
10 changes: 5 additions & 5 deletions pkg/doh/settings.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ import (

type ServerSettings struct {
Resolver ResolverSettings
ListeningAddress string
ListeningAddress *string
// Middlewares is a list of middlewares to use.
// The first one is the first wrapper, and the last one
// is the last wrapper of the handlers in the chain.
Expand Down Expand Up @@ -56,7 +56,7 @@ type SelfDNS struct {

func (s *ServerSettings) SetDefaults() {
s.Resolver.SetDefaults()
s.ListeningAddress = gosettings.DefaultString(s.ListeningAddress, ":53")
s.ListeningAddress = gosettings.DefaultPointer(s.ListeningAddress, ":53")
s.Logger = gosettings.DefaultInterface(s.Logger, lognoop.New())
}

Expand Down Expand Up @@ -89,9 +89,9 @@ func (s ServerSettings) Validate() (err error) {
}

const defaultUDPPort = 53
err = validate.ListeningAddress(s.ListeningAddress, os.Getuid(), defaultUDPPort)
err = validate.ListeningAddress(*s.ListeningAddress, os.Getuid(), defaultUDPPort)
if err != nil {
return fmt.Errorf("%w: %s", ErrListeningAddressNotValid, s.ListeningAddress)
return fmt.Errorf("%w: %s", ErrListeningAddressNotValid, *s.ListeningAddress)
}

return nil
Expand Down Expand Up @@ -145,7 +145,7 @@ func (s *SelfDNS) String() string {

func (s *ServerSettings) ToLinesNode() (node *gotree.Node) {
node = gotree.New("DoH server settings:")
node.Appendf("Listening address: %s", s.ListeningAddress)
node.Appendf("Listening address: %s", *s.ListeningAddress)
node.AppendNode(s.Resolver.ToLinesNode())
return node
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/doh/settings_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ func Test_ServerSettings_SetDefaults(t *testing.T) {
Metrics: metrics,
Picker: picker,
},
ListeningAddress: ":53",
ListeningAddress: ptrTo(":53"),
}
assert.Equal(t, expectedSettings, s)
}
Expand Down
5 changes: 5 additions & 0 deletions pkg/dot/helpers_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
package dot

func ptrTo[T any](value T) *T {
return &value
}
17 changes: 13 additions & 4 deletions pkg/dot/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,18 +36,23 @@ func Test_Resolver(t *testing.T) {
}

func Test_Server(t *testing.T) {
server, err := NewServer(ServerSettings{})
server, err := NewServer(ServerSettings{
ListeningAddress: ptrTo(""),
})
require.NoError(t, err)

runError, startErr := server.Start()
require.NoError(t, startErr)

listeningAddress, err := server.ListeningAddress()
require.NoError(t, err)

resolver := &net.Resolver{
PreferGo: true,
StrictErrors: true,
Dial: func(ctx context.Context, network string, address string) (net.Conn, error) {
dialer := &net.Dialer{Timeout: time.Second}
return dialer.DialContext(ctx, "udp", "127.0.0.1:53")
return dialer.DialContext(ctx, "udp", listeningAddress.String())
},
}

Expand Down Expand Up @@ -228,7 +233,7 @@ func Test_Server_Mocks(t *testing.T) {
filterMiddleware := filtermiddleware.New(filter)

logger := NewMockLogger(ctrl)
logger.EXPECT().Info("DNS server listening on :53")
logger.EXPECT().Info(mockhelp.NewMatcherRegex("DNS server listening on .*:[1-9][0-9]{0,4}"))

dotMetrics := NewMockMetrics(ctrl)
dotMetrics.EXPECT().
Expand Down Expand Up @@ -260,18 +265,22 @@ func Test_Server_Mocks(t *testing.T) {
Resolver: ResolverSettings{
Metrics: dotMetrics,
},
ListeningAddress: ptrTo(""),
})
require.NoError(t, err)

runError, startErr := server.Start()
require.NoError(t, startErr)

listeningAddress, err := server.ListeningAddress()
require.NoError(t, err)

resolver := &net.Resolver{
PreferGo: true,
StrictErrors: true,
Dial: func(ctx context.Context, network string, address string) (net.Conn, error) {
dialer := &net.Dialer{Timeout: time.Second}
return dialer.DialContext(ctx, "udp", "127.0.0.1:53")
return dialer.DialContext(ctx, "udp", listeningAddress.String())
},
}

Expand Down
Loading

0 comments on commit cbbf011

Please sign in to comment.