Skip to content

Commit

Permalink
Fix support for punycode and non-default ports in OIDC Discovery Prov…
Browse files Browse the repository at this point in the history
…ider (#2453)

This change fixes the recently added domain validation in the OIDC Discovery
Provider to adequately accomodate punycode and properly account for
host:port values in the `Host` field.

It also rewords some of the error messages produced by misconfiguration.

Signed-off-by: Andrew Harding <[email protected]>
  • Loading branch information
azdagron authored Aug 25, 2021
1 parent 63c382f commit 0b1ecea
Show file tree
Hide file tree
Showing 8 changed files with 228 additions and 49 deletions.
17 changes: 12 additions & 5 deletions support/oidc-discovery-provider/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,9 @@ The configuration file is **required** by the provider. It contains
| Key | Type | Required? | Description | Default |
| ---------------------- | --------| -------------- | -------------------------------------------------------- | -------- |
| `acme` | section | required[1] | Provides the ACME configuration. | |
| `domain` | string | required | The domain the provider is being served from. | |
| `insecure_addr` | string | optional[3] | Exposes the service on http. | |
| `allow_insecure_scheme` | string | optional[3] | Serves OIDC configuration response with HTTP url. | `false` |
| `domains` | strings | required | One or more domains the provider is being served from. | |
| `insecure_addr` | string | optional[3] | Exposes the service on http. | |
| `listen_socket_path` | string | required[1][3] | Path on disk to listen with a Unix Domain Socket. | |
| `log_format` | string | optional | Format of the logs (either `"TEXT"` or `"JSON"`) | `""` |
| `log_level` | string | required | Log level (one of `"error"`,`"warn"`,`"info"`,`"debug"`) | `"info"` |
Expand All @@ -51,6 +51,13 @@ The configuration file is **required** by the provider. It contains

[3]: The `allow_insecure_scheme` should only be used in a local development environment for testing purposes. It only works in conjunction with `insecure_addr` or `listen_socket_path`.

The `domains` configurable contains the list of domains the provider is
expected to be served from. If a request is received from a domain other than
one in the list (as determined by the Host or X-Forwarded-Host header), it
will be rejected. Likewise, when ACME is used, the `domains` list contains the
allowed domains for which certificates will be obtained. The TLS handshake
will terminate if another domain is requested.

#### ACME Section

| Key | Type | Required? | Description | Default |
Expand Down Expand Up @@ -88,7 +95,7 @@ The configuration file is **required** by the provider. It contains

```
log_level = "debug"
domain = "mypublicdomain.test"
domains = ["mypublicdomain.test"]
acme {
cache_dir = "/some/path/on/disk/to/cache/creds"
tos_accepted = true
Expand All @@ -102,7 +109,7 @@ server_api {

```
log_level = "debug"
domain = "mypublicdomain.test"
domains = ["mypublicdomain.test"]
acme {
cache_dir = "/some/path/on/disk/to/cache/creds"
tos_accepted = true
Expand All @@ -122,7 +129,7 @@ Nginx, Apache, or Envoy which supports reverse proxying to a unix socket.

```
log_level = "debug"
domain = "mypublicdomain.test"
domains = ["mypublicdomain.test"]
listen_socket_path = "/run/oidc-discovery-provider/server.sock"
workload_api {
Expand Down
16 changes: 8 additions & 8 deletions support/oidc-discovery-provider/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ type Config struct {
// Domain is the domain this provider will be hosted under. It is used
// when obtaining certs via ACME (unless ListenSocketPath is specified).
// Deprecated. Domains should be used instead.
// Deprecated: remove in 1.2.0
Domain string `hcl:"domain"`

// Domains are the domains this provider will be hosted under. Incoming requests
Expand Down Expand Up @@ -156,17 +157,16 @@ func ParseConfig(hclConfig string) (_ *Config, err error) {
c.LogLevel = defaultLogLevel
}

if c.Domain == "" && len(c.Domains) == 0 {
switch {
case c.Domain == "" && len(c.Domains) == 0:
return nil, errs.New("at least one domain must be configured")
}
if len(c.Domains) > 0 && c.Domain != "" {
return nil, errs.New("use `domains` configurable only, `domain` configurable is deprecated")
}

if c.Domain != "" {
case c.Domain != "" && len(c.Domains) == 0:
c.Domains = []string{c.Domain}
case c.Domain == "" && len(c.Domains) > 0:
c.Domains = dedupeList(c.Domains)
case c.Domain != "" && len(c.Domains) > 0:
return nil, errs.New("domain is deprecated and will be removed in a future release; please use domains instead")
}
c.Domains = dedupeList(c.Domains)

if c.ACME != nil {
c.ACME.CacheDir = defaultCacheDir
Expand Down
2 changes: 1 addition & 1 deletion support/oidc-discovery-provider/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ func TestParseConfig(t *testing.T) {
socket_path = "/some/socket/path"
}
`,
err: "use `domains` configurable only, `domain` configurable is deprecated",
err: "domain is deprecated and will be removed in a future release; please use domains instead",
},
{
name: "no ACME configuration",
Expand Down
50 changes: 50 additions & 0 deletions support/oidc-discovery-provider/domain_policy.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
package main

import (
"fmt"

"golang.org/x/net/idna"
)

type DomainPolicy = func(domain string) error

// DomainAllowlist returns a policy that allows any domain in the given domains
func DomainAllowlist(domains ...string) (DomainPolicy, error) {
allowlist := make(map[string]struct{}, len(domains))
for _, domain := range domains {
domainKey, err := toDomainKey(domain)
if err != nil {
return nil, err
}
allowlist[domainKey] = struct{}{}
}
return func(domain string) error {
domainKey, err := toDomainKey(domain)
if err != nil {
return err
}
if _, allowed := allowlist[domainKey]; !allowed {
return fmt.Errorf("domain %q is not allowed", domain)
}
return nil
}, nil
}

// AllowAnyDomain returns a policy that allows any domain
func AllowAnyDomain() DomainPolicy {
return func(domain string) error {
_, err := toDomainKey(domain)
return err
}
}

func toDomainKey(domain string) (string, error) {
punycode, err := idna.Lookup.ToASCII(domain)
if err != nil {
return "", fmt.Errorf("domain %q is not a valid domain name: %w", domain, err)
}
if punycode != domain {
return "", fmt.Errorf("domain %q must already be punycode encoded", domain)
}
return domain, nil
}
49 changes: 49 additions & 0 deletions support/oidc-discovery-provider/domain_policy_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
package main

import (
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestDomainAllowlist(t *testing.T) {
t.Run("unicode", func(t *testing.T) {
_, err := DomainAllowlist("😬.test")
assert.EqualError(t, err, `domain "😬.test" must already be punycode encoded`)
})

t.Run("punycode", func(t *testing.T) {
policy, err := DomainAllowlist("xn--n38h.test")
require.NoError(t, err)
assert.EqualError(t, policy("😬.test"), `domain "😬.test" must already be punycode encoded`)
assert.NoError(t, policy("xn--n38h.test"))
assert.EqualError(t, policy("bad.test"), `domain "bad.test" is not allowed`)
})

t.Run("ascii", func(t *testing.T) {
policy, err := DomainAllowlist("ascii.test")
require.NoError(t, err)
assert.NoError(t, policy("ascii.test"))
assert.EqualError(t, policy("bad.test"), `domain "bad.test" is not allowed`)
})

t.Run("invalid domain in config", func(t *testing.T) {
_, err := DomainAllowlist("invalid/domain.test")
assert.EqualError(t, err, `domain "invalid/domain.test" is not a valid domain name: idna: disallowed rune U+002F`)
})

t.Run("invalid domain on lookup", func(t *testing.T) {
policy, err := DomainAllowlist()
require.NoError(t, err)
assert.EqualError(t, policy("invalid/domain.test"), `domain "invalid/domain.test" is not a valid domain name: idna: disallowed rune U+002F`)
})
}

func TestAllowAnyDomain(t *testing.T) {
policy := AllowAnyDomain()
assert.NoError(t, policy("foo"))
assert.NoError(t, policy("bar"))
assert.NoError(t, policy("baz"))
assert.EqualError(t, policy("invalid/domain.test"), `domain "invalid/domain.test" is not a valid domain name: idna: disallowed rune U+002F`)
}
35 changes: 19 additions & 16 deletions support/oidc-discovery-provider/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,33 +3,26 @@ package main
import (
"bytes"
"encoding/json"
"net"
"net/http"
"net/url"

"github.com/gorilla/handlers"
)

type Handler struct {
domains map[string]struct{}
source JWKSSource
domainPolicy DomainPolicy
allowInsecureScheme bool
performDomainCheck bool

http.Handler
}

func NewHandler(domains []string, source JWKSSource, allowInsecureScheme bool, performDomainCheck bool) *Handler {
allowedDomains := make(map[string]struct{})

for _, d := range domains {
allowedDomains[d] = struct{}{}
}

func NewHandler(domainPolicy DomainPolicy, source JWKSSource, allowInsecureScheme bool) *Handler {
h := &Handler{
domains: allowedDomains,
domainPolicy: domainPolicy,
source: source,
allowInsecureScheme: allowInsecureScheme,
performDomainCheck: performDomainCheck,
}

mux := http.NewServeMux()
Expand All @@ -46,11 +39,9 @@ func (h *Handler) serveWellKnown(w http.ResponseWriter, r *http.Request) {
return
}

if h.performDomainCheck {
if _, ok := h.domains[r.Host]; !ok {
http.Error(w, "domain not allowed", http.StatusNotFound)
return
}
if err := h.verifyHost(r.Host); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}

urlScheme := "https"
Expand Down Expand Up @@ -125,3 +116,15 @@ func (h *Handler) serveKeys(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
http.ServeContent(w, r, "keys", modTime, bytes.NewReader(jwksBytes))
}

func (h *Handler) verifyHost(host string) error {
// Obtain the domain name from the host value, which comes from the
// request, or is pulled from the X-Forwarded-Host header (via the
// ProxyHeaders middleware). The value may be in host or host:port form.
domain, _, err := net.SplitHostPort(host)
if err != nil {
// `Host` was not in the host:port form form.
domain = host
}
return h.domainPolicy(domain)
}
Loading

0 comments on commit 0b1ecea

Please sign in to comment.