Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable DoT connection reuse while requests are in flight #269

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 60 additions & 18 deletions upstream/upstream_dot.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@ package upstream

import (
"fmt"
"net"
"net/url"
"runtime"
"sync"

"github.com/AdguardTeam/golibs/errors"
Expand Down Expand Up @@ -53,14 +53,16 @@ func (p *dnsOverTLS) Exchange(m *dns.Msg) (reply *dns.Msg, err error) {
}

p.RLock()
poolConn, err := p.pool.Get()
poolConnAndStore, err := p.pool.Get()
// Put the connection right back in to allow the connection to be reused while requests are in flight
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is generally the same as using a single connection and not using a pool at all. Why keep it then?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Getting rid of the connections pool and using a single connection is possible, but it does have visible downsides. Namely, head-of-line blocking will be a huge problem when the connection is not stable.

However, if we operate locally (for instance, dnsproxy proxies a recursor on the same server) using pipelining is in theory better.

Here's what I suggest:

  1. Try implementing a different type of upstream that uses a single connection (in a separate file).
  2. Write tests for it and check that pipelining actually works and that it keeps that single connection alive.
  3. IMPORTANT: we need to be ready that the server can close the connection any time. In this case we should first try re-establishing the connection and repeating the request.

Then I'll examine the implementation and if it's okay we can merge it with existing DoT and DNS-over-TCP upstreams and allow enabling pipelining mode via upstream.Options.

p.pool.Put(poolConnAndStore)
p.RUnlock()
if err != nil {
return nil, fmt.Errorf("getting connection to %s: %w", p.Address(), err)
}

logBegin(p.Address(), m)
reply, err = p.exchangeConn(poolConn, m)
reply, err = p.exchangeConn(poolConnAndStore, m)
logFinish(p.Address(), err)
if err != nil {
log.Tracef("The TLS connection is expired due to %s", err)
Expand All @@ -70,49 +72,89 @@ func (p *dnsOverTLS) Exchange(m *dns.Msg) (reply *dns.Msg, err error) {
// We are forcing creation of a new connection instead of calling Get() again
// as there's no guarantee that other pooled connections are intact
p.RLock()
poolConn, err = p.pool.Create()
poolConnAndStore, err = p.pool.Create()
p.RUnlock()
if err != nil {
return nil, fmt.Errorf("creating new connection to %s: %w", p.Address(), err)
}

// Retry sending the DNS request
logBegin(p.Address(), m)
reply, err = p.exchangeConn(poolConn, m)
reply, err = p.exchangeConn(poolConnAndStore, m)
logFinish(p.Address(), err)
}

if err == nil {
p.RLock()
p.pool.Put(poolConn)
p.RUnlock()
}
return reply, err
}

func (p *dnsOverTLS) exchangeConn(conn net.Conn, m *dns.Msg) (reply *dns.Msg, err error) {
func (p *dnsOverTLS) exchangeConn(connAndStore *connAndStore, m *dns.Msg) (reply *dns.Msg, err error) {
defer func() {
if err == nil {
return
}

if cerr := conn.Close(); cerr != nil {
if cerr := connAndStore.conn.Close(); cerr != nil {
err = &errors.Pair{Returned: err, Deferred: cerr}
}
}()

dnsConn := dns.Conn{Conn: conn}
dnsConn := dns.Conn{Conn: connAndStore.conn}

err = dnsConn.WriteMsg(m)
if err != nil {
return nil, fmt.Errorf("sending request to %s: %w", p.Address(), err)
}

reply, err = dnsConn.ReadMsg()
if err != nil {
return nil, fmt.Errorf("reading response from %s: %w", p.Address(), err)
} else if reply.Id != m.Id {
err = dns.ErrId
// Since we might receive out-of-order responses when processing multiple queries through a single upstream (cf.
// PR #269), we will store all responses that don't match our DNS ID and retry until we find the response we are
// looking for (either by receiving it directly or by finding it in the stored responses).
responseFound := false
present := false
for !responseFound {
connAndStore.Lock()

// has someone already received our response?
reply, present = connAndStore.store[m.Id]
if present { // matching response in store
log.Tracef("Found matching ID in store for request %d", m.Id)
delete(connAndStore.store, m.Id) // delete response from store
responseFound = true
} else { // no matching response in store
reply, err = dnsConn.ReadMsg()
if err != nil {
connAndStore.Unlock()
return nil, fmt.Errorf("reading response from %s: %w", p.Address(), err)
} else if reply.Id != m.Id {
// not the response we were looking for -> store it in the store
log.Tracef("Received unknown ID %d, storing in store for later use", reply.Id)
connAndStore.store[reply.Id] = reply
} else {
responseFound = true
}
}
connAndStore.Unlock()

// yield to scheduler if we added something to the store
if !responseFound {
runtime.Gosched()
}
}

// Match response QNAME, QCLASS, and QTYPE to query according to RFC 7766
// (https://www.rfc-editor.org/rfc/rfc7766#section-7)
if len(reply.Question) != 0 && len(m.Question) != 0 {
if reply.Question[0].Name != m.Question[0].Name {
err = fmt.Errorf("Query and response QNAME do not match; received %s, expected %s", reply.Question[0].Name, m.Question[0].Name)
return reply, err
}
if reply.Question[0].Qtype != m.Question[0].Qtype {
err = fmt.Errorf("Query and response QTYPE do not match; received %d, expected %d", reply.Question[0].Qtype, m.Question[0].Qtype)
return reply, err
}
if reply.Question[0].Qclass != m.Question[0].Qclass {
err = fmt.Errorf("Query and response QCLASS do not match; received %d, expected %d", reply.Question[0].Qclass, m.Question[0].Qclass)
return reply, err
}
}

return reply, err
Expand Down
28 changes: 20 additions & 8 deletions upstream/upstream_pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"time"

"github.com/AdguardTeam/golibs/log"
"github.com/miekg/dns"
)

// dialTimeout is the global timeout for establishing a TLS connection.
Expand Down Expand Up @@ -36,16 +37,24 @@ type TLSPool struct {
boot *bootstrapper

// conns is the list of connections available in the pool.
conns []net.Conn
conns []*connAndStore
// connsMutex protects conns.
connsMutex sync.Mutex
}

// connAndStore is a struct that assigns a store for out-of-order responses to each connection.
// We need this to process multiple queries through a single upstream (cf. PR #269).
type connAndStore struct {
	// conn is the underlying (TLS) network connection shared by in-flight queries.
	conn net.Conn
	// store saves out-of-order responses, keyed by DNS message ID, so that a
	// query whose reply was read by another goroutine reusing this connection
	// can still find it (needed to save out-of-order responses when reusing
	// the connection).
	store map[uint16]*dns.Msg
	// Mutex protects store.
	sync.Mutex
}

// Get gets a connection from the pool (if there's one available) or creates
// a new TLS connection.
func (n *TLSPool) Get() (net.Conn, error) {
func (n *TLSPool) Get() (*connAndStore, error) {
// Get the connection from the slice inside the lock.
var c net.Conn
var c *connAndStore
n.connsMutex.Lock()
num := len(n.conns)
if num > 0 {
Expand All @@ -57,11 +66,11 @@ func (n *TLSPool) Get() (net.Conn, error) {

// If we got connection from the slice, update deadline and return it.
if c != nil {
err := c.SetDeadline(time.Now().Add(dialTimeout))
err := c.conn.SetDeadline(time.Now().Add(dialTimeout))

// If deadLine can't be updated it means that connection was already closed
if err == nil {
log.Tracef("Returning existing connection to %s with updated deadLine", c.RemoteAddr())
log.Tracef("Returning existing connection to %s with updated deadLine", c.conn.RemoteAddr())
return c, nil
}
}
Expand All @@ -70,7 +79,7 @@ func (n *TLSPool) Get() (net.Conn, error) {
}

// Create creates a new connection for the pool (but does not put it there).
func (n *TLSPool) Create() (net.Conn, error) {
func (n *TLSPool) Create() (*connAndStore, error) {
tlsConfig, dialContext, err := n.boot.get()
if err != nil {
return nil, err
Expand All @@ -82,11 +91,14 @@ func (n *TLSPool) Create() (net.Conn, error) {
return nil, fmt.Errorf("connecting to %s: %w", tlsConfig.ServerName, err)
}

return conn, nil
// initialize the store
store := make(map[uint16]*dns.Msg)

return &connAndStore{conn: conn, store: store}, nil
}

// Put returns the connection to the pool.
func (n *TLSPool) Put(c net.Conn) {
func (n *TLSPool) Put(c *connAndStore) {
if c == nil {
return
}
Expand Down
22 changes: 11 additions & 11 deletions upstream/upstream_pool_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,9 @@ func TestTLSPoolReconnect(t *testing.T) {

// Now let's close the pooled connection and return it back to the pool.
p := u.(*dnsOverTLS)
conn, _ := p.pool.Get()
conn.Close()
p.pool.Put(conn)
connAndStore, _ := p.pool.Get()
connAndStore.conn.Close()
p.pool.Put(connAndStore)

// Send the second test message.
req = createTestMessage()
Expand Down Expand Up @@ -72,42 +72,42 @@ func TestTLSPoolDeadLine(t *testing.T) {
p := u.(*dnsOverTLS)

// Now let's get connection from the pool and use it
conn, err := p.pool.Get()
connAndStore, err := p.pool.Get()
if err != nil {
t.Fatalf("couldn't get connection from pool: %s", err)
}
response, err = p.exchangeConn(conn, req)
response, err = p.exchangeConn(connAndStore, req)
if err != nil {
t.Fatalf("first DNS message failed: %s", err)
}
requireResponse(t, req, response)

// Update connection's deadLine and put it back to the pool
err = conn.SetDeadline(time.Now().Add(10 * time.Hour))
err = connAndStore.conn.SetDeadline(time.Now().Add(10 * time.Hour))
if err != nil {
t.Fatalf("can't set new deadLine for connection. Looks like it's already closed: %s", err)
}
p.pool.Put(conn)
p.pool.Put(connAndStore)

// Get connection from the pool and reuse it
conn, err = p.pool.Get()
connAndStore, err = p.pool.Get()
if err != nil {
t.Fatalf("couldn't get connection from pool: %s", err)
}
response, err = p.exchangeConn(conn, req)
response, err = p.exchangeConn(connAndStore, req)
if err != nil {
t.Fatalf("first DNS message failed: %s", err)
}
requireResponse(t, req, response)

// Set connection's deadLine to the past and try to reuse it
err = conn.SetDeadline(time.Now().Add(-10 * time.Hour))
err = connAndStore.conn.SetDeadline(time.Now().Add(-10 * time.Hour))
if err != nil {
t.Fatalf("can't set new deadLine for connection. Looks like it's already closed: %s", err)
}

// Connection with expired deadLine can't be used
response, err = p.exchangeConn(conn, req)
response, err = p.exchangeConn(connAndStore, req)
if err == nil {
t.Fatalf("this connection should be already closed, got response %s", response)
}
Expand Down