Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[v17] Sanitize SSH server hostnames #49091

Merged
merged 3 commits into from
Nov 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 83 additions & 0 deletions lib/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,9 @@ import (
"log/slog"
"math/big"
insecurerand "math/rand"
"net"
"os"
"regexp"
"slices"
"sort"
"strconv"
Expand Down Expand Up @@ -1470,6 +1472,28 @@ func (a *Server) runPeriodicOperations() {
if services.NodeHasMissedKeepAlives(srv) {
missedKeepAliveCount++
}

// TODO(tross) DELETE in v20.0.0 - all invalid hostnames should have been sanitized by then.
if !validServerHostname(srv.GetHostname()) {
if srv.GetSubKind() != types.SubKindOpenSSHNode {
return false, nil
}

logger := a.logger.With("server", srv.GetName(), "hostname", srv.GetHostname())

logger.DebugContext(a.closeCtx, "sanitizing invalid static SSH server hostname")
// Any existing static hosts will not have their
// hostname sanitized since they don't heartbeat.
if err := sanitizeHostname(srv); err != nil {
logger.WarnContext(a.closeCtx, "failed to sanitize static SSH server hostname", "error", err)
return false, nil
}

if _, err := a.Services.UpdateNode(a.closeCtx, srv); err != nil && !trace.IsCompareFailed(err) {
logger.WarnContext(a.closeCtx, "failed to update SSH server hostname", "error", err)
}
}

return false, nil
},
req,
Expand Down Expand Up @@ -5618,9 +5642,68 @@ func (a *Server) KeepAliveServer(ctx context.Context, h types.KeepAlive) error {
return nil
}

const (
serverHostnameMaxLen = 256
serverHostnameRegexPattern = `^[a-zA-Z0-9]([\.-]?[a-zA-Z0-9]+)*$`
replacedHostnameLabel = types.TeleportInternalLabelPrefix + "invalid-hostname"
)

var serverHostnameRegex = regexp.MustCompile(serverHostnameRegexPattern)

// validServerHostname returns false if the hostname is longer than 256 characters or
// does not entirely consist of alphanumeric characters as well as '-' and '.'. A valid hostname also
// cannot begin with a symbol, and a symbol cannot be followed immediately by another symbol.
func validServerHostname(hostname string) bool {
return len(hostname) <= serverHostnameMaxLen && serverHostnameRegex.MatchString(hostname)
}

func sanitizeHostname(server types.Server) error {
invalidHostname := server.GetHostname()

replacedHostname := server.GetName()
if server.GetSubKind() == types.SubKindOpenSSHNode {
host, _, err := net.SplitHostPort(server.GetAddr())
if err != nil || !validServerHostname(host) {
id, err := uuid.NewRandom()
if err != nil {
return trace.Wrap(err)
}

host = id.String()
}

replacedHostname = host
}

switch s := server.(type) {
case *types.ServerV2:
s.Spec.Hostname = replacedHostname

if s.Metadata.Labels == nil {
s.Metadata.Labels = map[string]string{}
}

s.Metadata.Labels[replacedHostnameLabel] = invalidHostname
default:
return trace.BadParameter("invalid server provided")
}

return nil
}

// UpsertNode implements [services.Presence] by delegating to [Server.Services]
// and potentially emitting a [usagereporter] event.
func (a *Server) UpsertNode(ctx context.Context, server types.Server) (*types.KeepAlive, error) {
if !validServerHostname(server.GetHostname()) {
a.logger.DebugContext(a.closeCtx, "sanitizing invalid server hostname",
"server", server.GetName(),
"hostname", server.GetHostname(),
)
if err := sanitizeHostname(server); err != nil {
return nil, trace.Wrap(err)
}
}

lease, err := a.Services.UpsertNode(ctx, server)
if err != nil {
return nil, trace.Wrap(err)
Expand Down
138 changes: 128 additions & 10 deletions lib/auth/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
package auth

import (
"cmp"
"context"
"crypto/rand"
"crypto/x509"
Expand All @@ -34,7 +35,7 @@ import (
"testing"
"time"

"github.com/google/go-cmp/cmp"
gocmp "github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/google/uuid"
"github.com/gravitational/license"
Expand Down Expand Up @@ -307,7 +308,7 @@ func TestSessions(t *testing.T) {
require.NoError(t, err)
assert.Empty(t, out.GetSSHPriv())
assert.Empty(t, out.GetTLSPriv())
assert.Empty(t, cmp.Diff(ws, out,
assert.Empty(t, gocmp.Diff(ws, out,
cmpopts.IgnoreFields(types.Metadata{}, "Revision"),
cmpopts.IgnoreFields(types.WebSessionSpecV2{}, "Priv", "TLSPriv")))

Expand Down Expand Up @@ -1655,7 +1656,7 @@ func TestServer_AugmentContextUserCertificates(t *testing.T) {
AssetTag: test.opts.DeviceExtensions.AssetTag,
CredentialId: test.opts.DeviceExtensions.CredentialID,
}
if diff := cmp.Diff(want, got); diff != "" {
if diff := gocmp.Diff(want, got); diff != "" {
t.Errorf("certEvent.Identity.DeviceExtensions mismatch (-want +got)\n%s", diff)
}
}
Expand Down Expand Up @@ -2301,12 +2302,12 @@ func TestServer_ExtendWebSession_deviceExtensions(t *testing.T) {
// Assert TLS extensions.
_, newIdentity := parseX509PEMAndIdentity(t, newSession.GetTLSCert())
wantExts := tlsca.DeviceExtensions(*deviceExts)
if diff := cmp.Diff(wantExts, newIdentity.DeviceExtensions); diff != "" {
if diff := gocmp.Diff(wantExts, newIdentity.DeviceExtensions); diff != "" {
t.Errorf("newSession.TLSCert DeviceExtensions mismatch (-want +got)\n%s", diff)
}

// Assert SSH extensions.
if diff := cmp.Diff(wantExts, parseSSHDeviceExtensions(t, newSession.GetPub())); diff != "" {
if diff := gocmp.Diff(wantExts, parseSSHDeviceExtensions(t, newSession.GetPub())); diff != "" {
t.Errorf("newSession.Pub DeviceExtensions mismatch (-want +got)\n%s", diff)
}
})
Expand Down Expand Up @@ -2545,7 +2546,7 @@ func TestGenerateUserCertWithCertExtension(t *testing.T) {
// Validate audit event.
lastEvent := p.mockEmitter.LastEvent()
require.IsType(t, &apievents.CertificateCreate{}, lastEvent)
require.Empty(t, cmp.Diff(
require.Empty(t, gocmp.Diff(
&apievents.CertificateCreate{
Metadata: apievents.Metadata{
Type: events.CertificateCreateEvent,
Expand Down Expand Up @@ -3801,15 +3802,15 @@ func compareDevices(t *testing.T, ignoreUpdateAndCounter bool, got []*types.MFAD
}

// Ignore LastUsed and SignatureCounter?
var opts []cmp.Option
var opts []gocmp.Option
if ignoreUpdateAndCounter {
opts = append(opts, cmp.FilterPath(func(path cmp.Path) bool {
opts = append(opts, gocmp.FilterPath(func(path gocmp.Path) bool {
p := path.String()
return p == "LastUsed" || p == "Device.Webauthn.SignatureCounter"
}, cmp.Ignore()))
}, gocmp.Ignore()))
}

if diff := cmp.Diff(want, got, opts...); diff != "" {
if diff := gocmp.Diff(want, got, opts...); diff != "" {
t.Errorf("compareDevices mismatch (-want +got):\n%s", diff)
}
}
Expand Down Expand Up @@ -4444,3 +4445,120 @@ func newGlobalNotificationWithExpiry(t *testing.T, title string, expires *timest

return &notification
}

// TestServerHostnameSanitization tests that persisting servers with
// "invalid" hostnames results in the hostname being sanitized and the
// illegal name being placed in a label.
func TestServerHostnameSanitization(t *testing.T) {
t.Parallel()
ctx := context.Background()
srv, err := NewTestAuthServer(TestAuthServerConfig{Dir: t.TempDir()})
require.NoError(t, err)

cases := []struct {
name string
hostname string
addr string
invalidHostname bool
invalidAddr bool
}{
{
name: "valid dns hostname",
hostname: "llama.example.com",
},
{
name: "valid friendly hostname",
hostname: "llama",
},
{
name: "uuid hostname",
hostname: uuid.NewString(),
},
{
name: "uuid dns hostname",
hostname: uuid.NewString() + ".example.com",
},
{
name: "empty hostname",
hostname: "",
invalidHostname: true,
},
{
name: "exceptionally long hostname",
hostname: strings.Repeat("a", serverHostnameMaxLen*2),
invalidHostname: true,
},
{
name: "invalid dns hostname",
hostname: "llama..example.com",
invalidHostname: true,
},
{
name: "spaces in hostname",
hostname: "the quick brown fox jumps over the lazy dog",
invalidHostname: true,
},
{
name: "invalid addr",
hostname: "..",
addr: "..:2345",
invalidHostname: true,
invalidAddr: true,
},
}

for _, test := range cases {
t.Run(test.name, func(t *testing.T) {
for _, subKind := range []string{types.KindNode, types.SubKindOpenSSHNode} {
t.Run(subKind, func(t *testing.T) {
server := &types.ServerV2{
Kind: types.KindNode,
SubKind: subKind,
Metadata: types.Metadata{
Name: uuid.NewString(),
},
Spec: types.ServerSpecV2{
Hostname: test.hostname,
Addr: cmp.Or(test.addr, "abcd:1234"),
},
}
if subKind == types.KindNode {
server.SubKind = ""
}

_, err = srv.AuthServer.UpsertNode(ctx, server)
require.NoError(t, err)

replacedValue, _ := server.GetLabel("teleport.internal/invalid-hostname")
if !test.invalidHostname {
assert.Equal(t, test.hostname, server.GetHostname())
assert.Empty(t, replacedValue)
return
}

assert.Equal(t, test.hostname, replacedValue)
switch subKind {
case types.SubKindOpenSSHNode:
host, _, err := net.SplitHostPort(server.GetAddr())
assert.NoError(t, err)
if !test.invalidAddr {
// If the address is valid, then the hostname should be set
// to the host of the addr field.
assert.Equal(t, host, server.GetHostname())
} else {
// If the address is not valid, then the hostname should be
// set to a UUID.
assert.NotEqual(t, host, server.GetHostname())
assert.NotEqual(t, server.GetName(), server.GetHostname())

_, err := uuid.Parse(server.GetHostname())
require.NoError(t, err)
}
default:
assert.Equal(t, server.GetName(), server.GetHostname())
}
})
}
})
}
}
4 changes: 2 additions & 2 deletions lib/auth/grpcserver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2955,9 +2955,9 @@ func TestNodesCRUD(t *testing.T) {
require.NoError(t, err)

// node1 and node2 will be added to default namespace
node1, err := types.NewServerWithLabels("node1", types.KindNode, types.ServerSpecV2{}, nil)
node1, err := types.NewServerWithLabels("node1", types.KindNode, types.ServerSpecV2{Hostname: "node1"}, nil)
require.NoError(t, err)
node2, err := types.NewServerWithLabels("node2", types.KindNode, types.ServerSpecV2{}, nil)
node2, err := types.NewServerWithLabels("node2", types.KindNode, types.ServerSpecV2{Hostname: "node2"}, nil)
require.NoError(t, err)

t.Run("CreateNode", func(t *testing.T) {
Expand Down
28 changes: 28 additions & 0 deletions lib/services/local/presence.go
Original file line number Diff line number Diff line change
Expand Up @@ -371,6 +371,34 @@ func (s *PresenceService) UpsertNode(ctx context.Context, server types.Server) (
}, nil
}

// UpdateNode conditionally updates the provided server.
func (s *PresenceService) UpdateNode(ctx context.Context, server types.Server) (types.Server, error) {
if server.GetNamespace() == "" {
server.SetNamespace(apidefaults.Namespace)
}

if n := server.GetNamespace(); n != apidefaults.Namespace {
return nil, trace.BadParameter("cannot place node in namespace %q, custom namespaces are deprecated", n)
}
rev := server.GetRevision()
value, err := services.MarshalServer(server)
if err != nil {
return nil, trace.Wrap(err)
}
lease, err := s.ConditionalUpdate(ctx, backend.Item{
Key: backend.NewKey(nodesPrefix, server.GetNamespace(), server.GetName()),
Value: value,
Expires: server.Expiry(),
Revision: rev,
})
if err != nil {
return nil, trace.Wrap(err)
}

server.SetRevision(lease.Revision)
return server, nil
}

// GetAuthServers returns a list of registered servers
func (s *PresenceService) GetAuthServers() ([]types.Server, error) {
return s.getServers(context.TODO(), types.KindAuthServer, authServersPrefix)
Expand Down
25 changes: 25 additions & 0 deletions lib/services/local/presence_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,31 @@ func TestNodeCRUD(t *testing.T) {
require.NoError(t, err)
})

t.Run("UpdateNode", func(t *testing.T) {
node1, err = presence.GetNode(ctx, apidefaults.Namespace, node1.GetName())
require.NoError(t, err)
node1.SetAddr("1.2.3.4:8080")

node2, err = presence.GetNode(ctx, apidefaults.Namespace, node2.GetName())
require.NoError(t, err)

node1, err = presence.UpdateNode(ctx, node1)
require.NoError(t, err)
require.Equal(t, "1.2.3.4:8080", node1.GetAddr())

rev := node2.GetRevision()
node2.SetAddr("1.2.3.4:9090")
node2.SetRevision(node1.GetRevision())

_, err = presence.UpdateNode(ctx, node2)
require.True(t, trace.IsCompareFailed(err))
node2.SetRevision(rev)

node2, err = presence.UpdateNode(ctx, node2)
require.NoError(t, err)
require.Equal(t, "1.2.3.4:9090", node2.GetAddr())
})

// Run NodeGetters in nested subtests to allow parallelization.
t.Run("NodeGetters", func(t *testing.T) {
t.Run("GetNodes", func(t *testing.T) {
Expand Down
1 change: 1 addition & 0 deletions lib/services/presence.go
Original file line number Diff line number Diff line change
Expand Up @@ -212,4 +212,5 @@ type PresenceInternal interface {
UpsertHostUserInteractionTime(ctx context.Context, name string, loginTime time.Time) error
GetHostUserInteractionTime(ctx context.Context, name string) (time.Time, error)
UpsertReverseTunnelV2(ctx context.Context, tunnel types.ReverseTunnel) (types.ReverseTunnel, error)
UpdateNode(ctx context.Context, server types.Server) (types.Server, error)
}
Loading
Loading