Skip to content

Commit

Permalink
Fix non-root hosts failing on resolving DNS
Browse files Browse the repository at this point in the history
A fix for a bug introduced by btcsuite#2168

Previously, config.Host worked in the following way:
1. Documented as supporting ip addresses only
2. In fact supported "host/path" syntax
3. Did not support "scheme" prefixes, i.e. https://

Not sure this is the desired approach, probably the best thing would
have been to extend config to contain "Scheme" and "Path" fields as well.

However, this was the way it worked.

1. Now Host can contain scheme prefixes "unix://..."
2. Host can no longer contain ".../path"

This PR solves this behavior while maintaining support of the "unix://" flow
as well.

For some reason, "scheme" is named "network" in btcsuite#2168 - I did not change that.

Also remove disambiguation in "network:address:port", where it parsed
"myhost:8888" as network:address instead address:port.
  • Loading branch information
same-id committed Nov 8, 2024
1 parent 24eb815 commit 47faac8
Show file tree
Hide file tree
Showing 2 changed files with 179 additions and 58 deletions.
127 changes: 69 additions & 58 deletions rpcclient/infrastructure.go
Original file line number Diff line number Diff line change
Expand Up @@ -759,41 +759,26 @@ out:
// result, unmarshalling it, and delivering the unmarshalled result to the
// provided response channel.
func (c *Client) handleSendPostMessage(jReq *jsonRequest) {
protocol := "http"
if !c.config.DisableTLS {
protocol = "https"
}

var (
err, lastErr error
lastErr error
backoff time.Duration
httpResponse *http.Response
)

parsedAddr, err := ParseAddressString(c.config.Host)
httpURL, err := c.config.httpURL()
if err != nil {
jReq.responseChan <- &Response{
err: fmt.Errorf("failed to parse address %v", err),
}
return
}

var url string
switch parsedAddr.Network() {
case "unix", "unixpacket":
// Using a placeholder URL because a non-empty URL is required.
// The Unix domain socket is specified in the DialContext.
url = protocol + "://unix"
default:
url = protocol + "://" + c.config.Host
}

tries := 10
for i := 0; i < tries; i++ {
var httpReq *http.Request

bodyReader := bytes.NewReader(jReq.marshalledJSON)
httpReq, err = http.NewRequest("POST", url, bodyReader)
httpReq, err = http.NewRequest("POST", httpURL, bodyReader)
if err != nil {
jReq.responseChan <- &Response{result: nil, err: err}
return
Expand Down Expand Up @@ -1355,16 +1340,21 @@ func newHTTPClient(config *ConnConfig) (*http.Client, error) {
}
}

parsedAddr, err := ParseAddressString(config.Host)
parsedDialAddr, err := ParseAddressString(config.Host)
if err != nil {
return nil, err
}
client := http.Client{
Transport: &http.Transport{
Proxy: proxyFunc,
TLSClientConfig: tlsConfig,
DialContext: func(_ context.Context, _, _ string) (net.Conn, error) {
return net.Dial(parsedAddr.Network(), parsedAddr.String())
DialContext: func(_ context.Context, _,
_ string) (net.Conn, error) {

return net.Dial(
parsedDialAddr.Network(),
parsedDialAddr.String(),
)
},
},
Timeout: defaultHTTPTimeout,
Expand All @@ -1373,6 +1363,32 @@ func newHTTPClient(config *ConnConfig) (*http.Client, error) {
return &client, nil
}

// httpURL returns the URL to use for HTTP POST requests.
func (config *ConnConfig) httpURL() (string, error) {
protocol := "http"
if !config.DisableTLS {
protocol = "https"
}

parsedAddr, err := ParseAddressString(config.Host)
if err != nil {
return "", fmt.Errorf("error parsing host '%v': %v",
config.Host, err)
}

var httpURL string
switch parsedAddr.Network() {
case "unix", "unixpacket":
// Using a placeholder URL because a non-empty URL is required.
// The Unix domain socket is specified in the DialContext.
httpURL = protocol + "://unix"
default:
httpURL = protocol + "://" + config.Host
}

return httpURL, nil
}

// dial opens a websocket connection using the passed connection configuration
// details.
func dial(config *ConnConfig) (*websocket.Conn, error) {
Expand Down Expand Up @@ -1733,53 +1749,48 @@ func (c *Client) Send() error {
return nil
}

// cutPrefix returns s without the provided leading prefix string
// and reports whether it found the prefix.
// If s doesn't start with prefix, CutPrefix returns s, false.
// If prefix is the empty string, CutPrefix returns s, true.
// Copied from go1.20 version.
func cutPrefix(s, prefix string) (after string, found bool) {
if !strings.HasPrefix(s, prefix) {
return s, false
}
return s[len(prefix):], true
}

// ParseAddressString converts an address in string format to a net.Addr that is
// compatible with btcd. UDP is not supported because btcd needs reliable
// connections. We accept a custom function to resolve any TCP addresses so
// that caller is able control exactly how resolution is performed.
// connections.
func ParseAddressString(strAddress string) (net.Addr, error) {
var parsedNetwork, parsedAddr string
// Addresses can either be in unix://address, unixpacket://address URL
// format, or just address:port host format for tcp.
if after, ok := cutPrefix(strAddress, "unix://"); ok {
return net.ResolveUnixAddr("unix", after)
}
if after, ok := cutPrefix(strAddress, "unixpacket://"); ok {
return net.ResolveUnixAddr("unixpacket", after)
}

// Addresses can either be in network://address:port format,
// network:address:port, address:port, or just port. We want to support
// all possible types.
if strings.Contains(strAddress, "://") {
parts := strings.Split(strAddress, "://")
parsedNetwork, parsedAddr = parts[0], parts[1]
} else if strings.Contains(strAddress, ":") {
parts := strings.Split(strAddress, ":")
parsedNetwork = parts[0]
parsedAddr = strings.Join(parts[1:], ":")
} else {
parsedAddr = strAddress
// Not supporting :// anywhere in the host or path.
return nil, fmt.Errorf("unsupported protocol in address: %s",
strAddress)
}

// Only TCP and Unix socket addresses are valid. We can't use IP or
// UDP only connections for anything we do in lnd.
switch parsedNetwork {
case "unix", "unixpacket":
return net.ResolveUnixAddr(parsedNetwork, parsedAddr)

case "tcp", "tcp4", "tcp6":
return net.ResolveTCPAddr(parsedNetwork, verifyPort(parsedAddr))

case "ip", "ip4", "ip6", "udp", "udp4", "udp6", "unixgram":
return nil, fmt.Errorf("only TCP or unix socket "+
"addresses are supported: %s", parsedAddr)

default:
// We'll now possibly use the local host short circuit
// or parse out an all interfaces listen.
addrWithPort := verifyPort(strAddress)

// Otherwise, we'll attempt to resolve the host.
return net.ResolveTCPAddr("tcp", addrWithPort)
// Parse it as a dummy URL to get the host and port.
u, err := url.Parse("dummy://" + strAddress)
if err != nil {
return nil, err
}
return net.ResolveTCPAddr("tcp", verifyPort(u.Host))
}

// verifyPort makes sure that an address string has both a host and a port.
// If the address is just a port, then we'll assume that the user is using the
// short cut to specify a localhost:port address.
// shortcut to specify a localhost:port address.
func verifyPort(address string) string {
host, port, err := net.SplitHostPort(address)
if err != nil {
Expand All @@ -1801,8 +1812,8 @@ func verifyPort(address string) string {
return net.JoinHostPort(address, "")
}

// In the case that both the host and port are empty, we'll use the
// an empty port.
// In the case that both the host and port are empty, we'll use an empty
// port.
if host == "" && port == "" {
return ":"
}
Expand Down
110 changes: 110 additions & 0 deletions rpcclient/infrastructure_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
package rpcclient

import (
"testing"

"github.com/stretchr/testify/require"
)

// TestParseAddressString checks different variation of supported and
// unsupported addresses.
func TestParseAddressString(t *testing.T) {
t.Parallel()

// Using localhost only to avoid network calls.
testCases := []struct {
name string
addressString string
expNetwork string
expAddress string
expErrStr string
}{
{
name: "localhost",
addressString: "localhost",
expNetwork: "tcp",
expAddress: "127.0.0.1:0",
},
{
name: "localhost ip",
addressString: "127.0.0.1",
expNetwork: "tcp",
expAddress: "127.0.0.1:0",
},
{
name: "localhost ipv6",
addressString: "::1",
expNetwork: "tcp",
expAddress: "[::1]:0",
},
{
name: "localhost and port",
addressString: "localhost:80",
expNetwork: "tcp",
expAddress: "127.0.0.1:80",
},
{
name: "localhost ipv6 and port",
addressString: "[::1]:80",
expNetwork: "tcp",
expAddress: "[::1]:80",
},
{
name: "colon and port",
addressString: ":80",
expNetwork: "tcp",
expAddress: ":80",
},
{
name: "colon only",
addressString: ":",
expNetwork: "tcp",
expAddress: ":0",
},
{
name: "localhost and path",
addressString: "localhost/path",
expNetwork: "tcp",
expAddress: "127.0.0.1:0",
},
{
name: "localhost port and path",
addressString: "localhost:80/path",
expNetwork: "tcp",
expAddress: "127.0.0.1:80",
},
{
name: "unix prefix",
addressString: "unix://the/rest/of/the/path",
expNetwork: "unix",
expAddress: "the/rest/of/the/path",
},
{
name: "unix prefix",
addressString: "unixpacket://the/rest/of/the/path",
expNetwork: "unixpacket",
expAddress: "the/rest/of/the/path",
},
{
name: "error http prefix",
addressString: "http://localhost:1010",
expErrStr: "unsupported protocol in address",
},
}

for _, tc := range testCases {
tc := tc

t.Run(tc.name, func(t *testing.T) {
addr, err := ParseAddressString(tc.addressString)
if tc.expErrStr != "" {
require.Error(t, err)
require.Contains(t, err.Error(), tc.expErrStr)
return
}
require.NoError(t, err)
require.Equal(t, tc.expNetwork, addr.Network())
require.Equal(t, tc.expAddress, addr.String())
})
}
}

0 comments on commit 47faac8

Please sign in to comment.