From 47faac886293ad541f47d4a2fb480469843c7b48 Mon Sep 17 00:00:00 2001 From: Sam Eiderman Date: Sun, 27 Oct 2024 14:19:17 +0200 Subject: [PATCH] Fix non-root hosts failing on resolving DNS A fix for a bug introduced by #2168 Previously, config.Host worked in the following way: 1. Documented as supporting ip addresses only 2. In fact supported "host/path" syntax 3. Did not support "scheme" prefixes, i.e. https:// Not sure this is the desired approach, probably the best thing would have been to extend config to contain "Scheme" and "Path" fields as well. However, this was the way it worked. 1. Now Host can contain scheme prefixes "unix://..." 2. Host can no longer contain ".../path" This PR solves this behavior while maintaining support of the "unix://" flow as well. For some reason, "scheme" is named "network" in #2168 - I did not change that. Also remove disambiguation in "network:address:port", where it parsed "myhost:8888" as network:address instead address:port. --- rpcclient/infrastructure.go | 127 +++++++++++++++++-------------- rpcclient/infrastructure_test.go | 110 ++++++++++++++++++++++++++ 2 files changed, 179 insertions(+), 58 deletions(-) create mode 100644 rpcclient/infrastructure_test.go diff --git a/rpcclient/infrastructure.go b/rpcclient/infrastructure.go index 4fe1d894df..9f570cec8e 100644 --- a/rpcclient/infrastructure.go +++ b/rpcclient/infrastructure.go @@ -759,18 +759,13 @@ out: // result, unmarshalling it, and delivering the unmarshalled result to the // provided response channel. func (c *Client) handleSendPostMessage(jReq *jsonRequest) { - protocol := "http" - if !c.config.DisableTLS { - protocol = "https" - } - var ( - err, lastErr error + lastErr error backoff time.Duration httpResponse *http.Response ) - parsedAddr, err := ParseAddressString(c.config.Host) + httpURL, err := c.config.httpURL() if err != nil { jReq.responseChan <- &Response{ err: fmt.Errorf("failed to parse address %v", err), @@ -778,22 +773,12 @@ func (c *Client) handleSendPostMessage(jReq *jsonRequest) { return } - var url string - switch parsedAddr.Network() { - case "unix", "unixpacket": - // Using a placeholder URL because a non-empty URL is required. - // The Unix domain socket is specified in the DialContext. - url = protocol + "://unix" - default: - url = protocol + "://" + c.config.Host - } - tries := 10 for i := 0; i < tries; i++ { var httpReq *http.Request bodyReader := bytes.NewReader(jReq.marshalledJSON) - httpReq, err = http.NewRequest("POST", url, bodyReader) + httpReq, err = http.NewRequest("POST", httpURL, bodyReader) if err != nil { jReq.responseChan <- &Response{result: nil, err: err} return @@ -1355,7 +1340,7 @@ func newHTTPClient(config *ConnConfig) (*http.Client, error) { } } - parsedAddr, err := ParseAddressString(config.Host) + parsedDialAddr, err := ParseAddressString(config.Host) if err != nil { return nil, err } @@ -1363,8 +1348,13 @@ func newHTTPClient(config *ConnConfig) (*http.Client, error) { Transport: &http.Transport{ Proxy: proxyFunc, TLSClientConfig: tlsConfig, - DialContext: func(_ context.Context, _, _ string) (net.Conn, error) { - return net.Dial(parsedAddr.Network(), parsedAddr.String()) + DialContext: func(_ context.Context, _, + _ string) (net.Conn, error) { + + return net.Dial( + parsedDialAddr.Network(), + parsedDialAddr.String(), + ) }, }, Timeout: defaultHTTPTimeout, @@ -1373,6 +1363,32 @@ func newHTTPClient(config *ConnConfig) (*http.Client, error) { return &client, nil } +// httpURL returns the URL to use for HTTP POST requests. +func (config *ConnConfig) httpURL() (string, error) { + protocol := "http" + if !config.DisableTLS { + protocol = "https" + } + + parsedAddr, err := ParseAddressString(config.Host) + if err != nil { + return "", fmt.Errorf("error parsing host '%v': %v", + config.Host, err) + } + + var httpURL string + switch parsedAddr.Network() { + case "unix", "unixpacket": + // Using a placeholder URL because a non-empty URL is required. + // The Unix domain socket is specified in the DialContext. + httpURL = protocol + "://unix" + default: + httpURL = protocol + "://" + config.Host + } + + return httpURL, nil +} + // dial opens a websocket connection using the passed connection configuration // details. func dial(config *ConnConfig) (*websocket.Conn, error) { @@ -1733,53 +1749,48 @@ func (c *Client) Send() error { return nil } +// cutPrefix returns s without the provided leading prefix string +// and reports whether it found the prefix. +// If s doesn't start with prefix, CutPrefix returns s, false. +// If prefix is the empty string, CutPrefix returns s, true. +// Copied from go1.20 version. +func cutPrefix(s, prefix string) (after string, found bool) { + if !strings.HasPrefix(s, prefix) { + return s, false + } + return s[len(prefix):], true +} + // ParseAddressString converts an address in string format to a net.Addr that is // compatible with btcd. UDP is not supported because btcd needs reliable -// connections. We accept a custom function to resolve any TCP addresses so -// that caller is able control exactly how resolution is performed. +// connections. func ParseAddressString(strAddress string) (net.Addr, error) { - var parsedNetwork, parsedAddr string + // Addresses can either be in unix://address, unixpacket://address URL + // format, or just address:port host format for tcp. + if after, ok := cutPrefix(strAddress, "unix://"); ok { + return net.ResolveUnixAddr("unix", after) + } + if after, ok := cutPrefix(strAddress, "unixpacket://"); ok { + return net.ResolveUnixAddr("unixpacket", after) + } - // Addresses can either be in network://address:port format, - // network:address:port, address:port, or just port. We want to support - // all possible types. if strings.Contains(strAddress, "://") { - parts := strings.Split(strAddress, "://") - parsedNetwork, parsedAddr = parts[0], parts[1] - } else if strings.Contains(strAddress, ":") { - parts := strings.Split(strAddress, ":") - parsedNetwork = parts[0] - parsedAddr = strings.Join(parts[1:], ":") - } else { - parsedAddr = strAddress + // Not supporting :// anywhere in the host or path. + return nil, fmt.Errorf("unsupported protocol in address: %s", + strAddress) } - // Only TCP and Unix socket addresses are valid. We can't use IP or - // UDP only connections for anything we do in lnd. - switch parsedNetwork { - case "unix", "unixpacket": - return net.ResolveUnixAddr(parsedNetwork, parsedAddr) - - case "tcp", "tcp4", "tcp6": - return net.ResolveTCPAddr(parsedNetwork, verifyPort(parsedAddr)) - - case "ip", "ip4", "ip6", "udp", "udp4", "udp6", "unixgram": - return nil, fmt.Errorf("only TCP or unix socket "+ - "addresses are supported: %s", parsedAddr) - - default: - // We'll now possibly use the local host short circuit - // or parse out an all interfaces listen. - addrWithPort := verifyPort(strAddress) - - // Otherwise, we'll attempt to resolve the host. - return net.ResolveTCPAddr("tcp", addrWithPort) + // Parse it as a dummy URL to get the host and port. + u, err := url.Parse("dummy://" + strAddress) + if err != nil { + return nil, err } + return net.ResolveTCPAddr("tcp", verifyPort(u.Host)) } // verifyPort makes sure that an address string has both a host and a port. // If the address is just a port, then we'll assume that the user is using the -// short cut to specify a localhost:port address. +// shortcut to specify a localhost:port address. func verifyPort(address string) string { host, port, err := net.SplitHostPort(address) if err != nil { @@ -1801,8 +1812,8 @@ func verifyPort(address string) string { return net.JoinHostPort(address, "") } - // In the case that both the host and port are empty, we'll use the - // an empty port. + // In the case that both the host and port are empty, we'll use an empty + // port. if host == "" && port == "" { return ":" } diff --git a/rpcclient/infrastructure_test.go b/rpcclient/infrastructure_test.go new file mode 100644 index 0000000000..8416b7ad3c --- /dev/null +++ b/rpcclient/infrastructure_test.go @@ -0,0 +1,110 @@ +package rpcclient + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +// TestParseAddressString checks different variation of supported and +// unsupported addresses. +func TestParseAddressString(t *testing.T) { + t.Parallel() + + // Using localhost only to avoid network calls. + testCases := []struct { + name string + addressString string + expNetwork string + expAddress string + expErrStr string + }{ + { + name: "localhost", + addressString: "localhost", + expNetwork: "tcp", + expAddress: "127.0.0.1:0", + }, + { + name: "localhost ip", + addressString: "127.0.0.1", + expNetwork: "tcp", + expAddress: "127.0.0.1:0", + }, + { + name: "localhost ipv6", + addressString: "::1", + expNetwork: "tcp", + expAddress: "[::1]:0", + }, + { + name: "localhost and port", + addressString: "localhost:80", + expNetwork: "tcp", + expAddress: "127.0.0.1:80", + }, + { + name: "localhost ipv6 and port", + addressString: "[::1]:80", + expNetwork: "tcp", + expAddress: "[::1]:80", + }, + { + name: "colon and port", + addressString: ":80", + expNetwork: "tcp", + expAddress: ":80", + }, + { + name: "colon only", + addressString: ":", + expNetwork: "tcp", + expAddress: ":0", + }, + { + name: "localhost and path", + addressString: "localhost/path", + expNetwork: "tcp", + expAddress: "127.0.0.1:0", + }, + { + name: "localhost port and path", + addressString: "localhost:80/path", + expNetwork: "tcp", + expAddress: "127.0.0.1:80", + }, + { + name: "unix prefix", + addressString: "unix://the/rest/of/the/path", + expNetwork: "unix", + expAddress: "the/rest/of/the/path", + }, + { + name: "unix prefix", + addressString: "unixpacket://the/rest/of/the/path", + expNetwork: "unixpacket", + expAddress: "the/rest/of/the/path", + }, + { + name: "error http prefix", + addressString: "http://localhost:1010", + expErrStr: "unsupported protocol in address", + }, + } + + for _, tc := range testCases { + tc := tc + + t.Run(tc.name, func(t *testing.T) { + addr, err := ParseAddressString(tc.addressString) + if tc.expErrStr != "" { + require.Error(t, err) + require.Contains(t, err.Error(), tc.expErrStr) + return + } + require.NoError(t, err) + require.Equal(t, tc.expNetwork, addr.Network()) + require.Equal(t, tc.expAddress, addr.String()) + }) + } +}