Skip to content

Commit 723f751

Browse files
authoredNov 15, 2024··
[v17] Sanitize SSH server hostnames (#49091)
* Sanitize SSH server hostnames Prevents any invalid and malicious hostnames, but replacing them with known valid data already associated with the host. This was chosen instead of rejecting to persist the server resource in an attempt to continue providing access to the host in order to remedy the invalid hostname. Any servers that represent a Teleport ssh_service with an invalid hostname will be replaced by the host UUID. Any static OpenSSH servers will have invalid hostnames replaced with the address. This will continue to allow the hosts to be dialable. In order to make these hosts discoverable, the invalid hostname will be set in the "teleport.internal/invalid-hostname" label. Updates gravitational/teleport-private#1676. * add and use internal update node method * add test coverage for UpdateNode
·
1 parent 60c78b2 commit 723f751

File tree

8 files changed

+273
-14
lines changed

8 files changed

+273
-14
lines changed
 

‎lib/auth/auth.go

+83
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,9 @@ import (
3939
"log/slog"
4040
"math/big"
4141
insecurerand "math/rand"
42+
"net"
4243
"os"
44+
"regexp"
4345
"slices"
4446
"sort"
4547
"strconv"
@@ -1470,6 +1472,28 @@ func (a *Server) runPeriodicOperations() {
14701472
if services.NodeHasMissedKeepAlives(srv) {
14711473
missedKeepAliveCount++
14721474
}
1475+
1476+
// TODO(tross) DELETE in v20.0.0 - all invalid hostnames should have been sanitized by then.
1477+
if !validServerHostname(srv.GetHostname()) {
1478+
if srv.GetSubKind() != types.SubKindOpenSSHNode {
1479+
return false, nil
1480+
}
1481+
1482+
logger := a.logger.With("server", srv.GetName(), "hostname", srv.GetHostname())
1483+
1484+
logger.DebugContext(a.closeCtx, "sanitizing invalid static SSH server hostname")
1485+
// Any existing static hosts will not have their
1486+
// hostname sanitized since they don't heartbeat.
1487+
if err := sanitizeHostname(srv); err != nil {
1488+
logger.WarnContext(a.closeCtx, "failed to sanitize static SSH server hostname", "error", err)
1489+
return false, nil
1490+
}
1491+
1492+
if _, err := a.Services.UpdateNode(a.closeCtx, srv); err != nil && !trace.IsCompareFailed(err) {
1493+
logger.WarnContext(a.closeCtx, "failed to update SSH server hostname", "error", err)
1494+
}
1495+
}
1496+
14731497
return false, nil
14741498
},
14751499
req,
@@ -5618,9 +5642,68 @@ func (a *Server) KeepAliveServer(ctx context.Context, h types.KeepAlive) error {
56185642
return nil
56195643
}
56205644

5645+
const (
5646+
serverHostnameMaxLen = 256
5647+
serverHostnameRegexPattern = `^[a-zA-Z0-9]([\.-]?[a-zA-Z0-9]+)*$`
5648+
replacedHostnameLabel = types.TeleportInternalLabelPrefix + "invalid-hostname"
5649+
)
5650+
5651+
var serverHostnameRegex = regexp.MustCompile(serverHostnameRegexPattern)
5652+
5653+
// validServerHostname returns false if the hostname is longer than 256 characters or
5654+
// does not entirely consist of alphanumeric characters as well as '-' and '.'. A valid hostname also
5655+
// cannot begin with a symbol, and a symbol cannot be followed immediately by another symbol.
5656+
func validServerHostname(hostname string) bool {
5657+
return len(hostname) <= serverHostnameMaxLen && serverHostnameRegex.MatchString(hostname)
5658+
}
5659+
5660+
func sanitizeHostname(server types.Server) error {
5661+
invalidHostname := server.GetHostname()
5662+
5663+
replacedHostname := server.GetName()
5664+
if server.GetSubKind() == types.SubKindOpenSSHNode {
5665+
host, _, err := net.SplitHostPort(server.GetAddr())
5666+
if err != nil || !validServerHostname(host) {
5667+
id, err := uuid.NewRandom()
5668+
if err != nil {
5669+
return trace.Wrap(err)
5670+
}
5671+
5672+
host = id.String()
5673+
}
5674+
5675+
replacedHostname = host
5676+
}
5677+
5678+
switch s := server.(type) {
5679+
case *types.ServerV2:
5680+
s.Spec.Hostname = replacedHostname
5681+
5682+
if s.Metadata.Labels == nil {
5683+
s.Metadata.Labels = map[string]string{}
5684+
}
5685+
5686+
s.Metadata.Labels[replacedHostnameLabel] = invalidHostname
5687+
default:
5688+
return trace.BadParameter("invalid server provided")
5689+
}
5690+
5691+
return nil
5692+
}
5693+
56215694
// UpsertNode implements [services.Presence] by delegating to [Server.Services]
56225695
// and potentially emitting a [usagereporter] event.
56235696
func (a *Server) UpsertNode(ctx context.Context, server types.Server) (*types.KeepAlive, error) {
5697+
if !validServerHostname(server.GetHostname()) {
5698+
a.logger.DebugContext(a.closeCtx, "sanitizing invalid server hostname",
5699+
"server", server.GetName(),
5700+
"hostname", server.GetHostname(),
5701+
)
5702+
if err := sanitizeHostname(server); err != nil {
5703+
return nil, trace.Wrap(err)
5704+
}
5705+
}
5706+
56245707
lease, err := a.Services.UpsertNode(ctx, server)
56255708
if err != nil {
56265709
return nil, trace.Wrap(err)

‎lib/auth/auth_test.go

+128-10
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
package auth
2020

2121
import (
22+
"cmp"
2223
"context"
2324
"crypto/rand"
2425
"crypto/x509"
@@ -34,7 +35,7 @@ import (
3435
"testing"
3536
"time"
3637

37-
"github.com/google/go-cmp/cmp"
38+
gocmp "github.com/google/go-cmp/cmp"
3839
"github.com/google/go-cmp/cmp/cmpopts"
3940
"github.com/google/uuid"
4041
"github.com/gravitational/license"
@@ -307,7 +308,7 @@ func TestSessions(t *testing.T) {
307308
require.NoError(t, err)
308309
assert.Empty(t, out.GetSSHPriv())
309310
assert.Empty(t, out.GetTLSPriv())
310-
assert.Empty(t, cmp.Diff(ws, out,
311+
assert.Empty(t, gocmp.Diff(ws, out,
311312
cmpopts.IgnoreFields(types.Metadata{}, "Revision"),
312313
cmpopts.IgnoreFields(types.WebSessionSpecV2{}, "Priv", "TLSPriv")))
313314

@@ -1655,7 +1656,7 @@ func TestServer_AugmentContextUserCertificates(t *testing.T) {
16551656
AssetTag: test.opts.DeviceExtensions.AssetTag,
16561657
CredentialId: test.opts.DeviceExtensions.CredentialID,
16571658
}
1658-
if diff := cmp.Diff(want, got); diff != "" {
1659+
if diff := gocmp.Diff(want, got); diff != "" {
16591660
t.Errorf("certEvent.Identity.DeviceExtensions mismatch (-want +got)\n%s", diff)
16601661
}
16611662
}
@@ -2301,12 +2302,12 @@ func TestServer_ExtendWebSession_deviceExtensions(t *testing.T) {
23012302
// Assert TLS extensions.
23022303
_, newIdentity := parseX509PEMAndIdentity(t, newSession.GetTLSCert())
23032304
wantExts := tlsca.DeviceExtensions(*deviceExts)
2304-
if diff := cmp.Diff(wantExts, newIdentity.DeviceExtensions); diff != "" {
2305+
if diff := gocmp.Diff(wantExts, newIdentity.DeviceExtensions); diff != "" {
23052306
t.Errorf("newSession.TLSCert DeviceExtensions mismatch (-want +got)\n%s", diff)
23062307
}
23072308

23082309
// Assert SSH extensions.
2309-
if diff := cmp.Diff(wantExts, parseSSHDeviceExtensions(t, newSession.GetPub())); diff != "" {
2310+
if diff := gocmp.Diff(wantExts, parseSSHDeviceExtensions(t, newSession.GetPub())); diff != "" {
23102311
t.Errorf("newSession.Pub DeviceExtensions mismatch (-want +got)\n%s", diff)
23112312
}
23122313
})
@@ -2545,7 +2546,7 @@ func TestGenerateUserCertWithCertExtension(t *testing.T) {
25452546
// Validate audit event.
25462547
lastEvent := p.mockEmitter.LastEvent()
25472548
require.IsType(t, &apievents.CertificateCreate{}, lastEvent)
2548-
require.Empty(t, cmp.Diff(
2549+
require.Empty(t, gocmp.Diff(
25492550
&apievents.CertificateCreate{
25502551
Metadata: apievents.Metadata{
25512552
Type: events.CertificateCreateEvent,
@@ -3801,15 +3802,15 @@ func compareDevices(t *testing.T, ignoreUpdateAndCounter bool, got []*types.MFAD
38013802
}
38023803

38033804
// Ignore LastUsed and SignatureCounter?
3804-
var opts []cmp.Option
3805+
var opts []gocmp.Option
38053806
if ignoreUpdateAndCounter {
3806-
opts = append(opts, cmp.FilterPath(func(path cmp.Path) bool {
3807+
opts = append(opts, gocmp.FilterPath(func(path gocmp.Path) bool {
38073808
p := path.String()
38083809
return p == "LastUsed" || p == "Device.Webauthn.SignatureCounter"
3809-
}, cmp.Ignore()))
3810+
}, gocmp.Ignore()))
38103811
}
38113812

3812-
if diff := cmp.Diff(want, got, opts...); diff != "" {
3813+
if diff := gocmp.Diff(want, got, opts...); diff != "" {
38133814
t.Errorf("compareDevices mismatch (-want +got):\n%s", diff)
38143815
}
38153816
}
@@ -4444,3 +4445,120 @@ func newGlobalNotificationWithExpiry(t *testing.T, title string, expires *timest
44444445

44454446
return &notification
44464447
}
4448+
4449+
// TestServerHostnameSanitization tests that persisting servers with
4450+
// "invalid" hostnames results in the hostname being sanitized and the
4451+
// illegal name being placed in a label.
4452+
func TestServerHostnameSanitization(t *testing.T) {
4453+
t.Parallel()
4454+
ctx := context.Background()
4455+
srv, err := NewTestAuthServer(TestAuthServerConfig{Dir: t.TempDir()})
4456+
require.NoError(t, err)
4457+
4458+
cases := []struct {
4459+
name string
4460+
hostname string
4461+
addr string
4462+
invalidHostname bool
4463+
invalidAddr bool
4464+
}{
4465+
{
4466+
name: "valid dns hostname",
4467+
hostname: "llama.example.com",
4468+
},
4469+
{
4470+
name: "valid friendly hostname",
4471+
hostname: "llama",
4472+
},
4473+
{
4474+
name: "uuid hostname",
4475+
hostname: uuid.NewString(),
4476+
},
4477+
{
4478+
name: "uuid dns hostname",
4479+
hostname: uuid.NewString() + ".example.com",
4480+
},
4481+
{
4482+
name: "empty hostname",
4483+
hostname: "",
4484+
invalidHostname: true,
4485+
},
4486+
{
4487+
name: "exceptionally long hostname",
4488+
hostname: strings.Repeat("a", serverHostnameMaxLen*2),
4489+
invalidHostname: true,
4490+
},
4491+
{
4492+
name: "invalid dns hostname",
4493+
hostname: "llama..example.com",
4494+
invalidHostname: true,
4495+
},
4496+
{
4497+
name: "spaces in hostname",
4498+
hostname: "the quick brown fox jumps over the lazy dog",
4499+
invalidHostname: true,
4500+
},
4501+
{
4502+
name: "invalid addr",
4503+
hostname: "..",
4504+
addr: "..:2345",
4505+
invalidHostname: true,
4506+
invalidAddr: true,
4507+
},
4508+
}
4509+
4510+
for _, test := range cases {
4511+
t.Run(test.name, func(t *testing.T) {
4512+
for _, subKind := range []string{types.KindNode, types.SubKindOpenSSHNode} {
4513+
t.Run(subKind, func(t *testing.T) {
4514+
server := &types.ServerV2{
4515+
Kind: types.KindNode,
4516+
SubKind: subKind,
4517+
Metadata: types.Metadata{
4518+
Name: uuid.NewString(),
4519+
},
4520+
Spec: types.ServerSpecV2{
4521+
Hostname: test.hostname,
4522+
Addr: cmp.Or(test.addr, "abcd:1234"),
4523+
},
4524+
}
4525+
if subKind == types.KindNode {
4526+
server.SubKind = ""
4527+
}
4528+
4529+
_, err = srv.AuthServer.UpsertNode(ctx, server)
4530+
require.NoError(t, err)
4531+
4532+
replacedValue, _ := server.GetLabel("teleport.internal/invalid-hostname")
4533+
if !test.invalidHostname {
4534+
assert.Equal(t, test.hostname, server.GetHostname())
4535+
assert.Empty(t, replacedValue)
4536+
return
4537+
}
4538+
4539+
assert.Equal(t, test.hostname, replacedValue)
4540+
switch subKind {
4541+
case types.SubKindOpenSSHNode:
4542+
host, _, err := net.SplitHostPort(server.GetAddr())
4543+
assert.NoError(t, err)
4544+
if !test.invalidAddr {
4545+
// If the address is valid, then the hostname should be set
4546+
// to the host of the addr field.
4547+
assert.Equal(t, host, server.GetHostname())
4548+
} else {
4549+
// If the address is not valid, then the hostname should be
4550+
// set to a UUID.
4551+
assert.NotEqual(t, host, server.GetHostname())
4552+
assert.NotEqual(t, server.GetName(), server.GetHostname())
4553+
4554+
_, err := uuid.Parse(server.GetHostname())
4555+
require.NoError(t, err)
4556+
}
4557+
default:
4558+
assert.Equal(t, server.GetName(), server.GetHostname())
4559+
}
4560+
})
4561+
}
4562+
})
4563+
}
4564+
}

‎lib/auth/grpcserver_test.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -2955,9 +2955,9 @@ func TestNodesCRUD(t *testing.T) {
29552955
require.NoError(t, err)
29562956

29572957
// node1 and node2 will be added to default namespace
2958-
node1, err := types.NewServerWithLabels("node1", types.KindNode, types.ServerSpecV2{}, nil)
2958+
node1, err := types.NewServerWithLabels("node1", types.KindNode, types.ServerSpecV2{Hostname: "node1"}, nil)
29592959
require.NoError(t, err)
2960-
node2, err := types.NewServerWithLabels("node2", types.KindNode, types.ServerSpecV2{}, nil)
2960+
node2, err := types.NewServerWithLabels("node2", types.KindNode, types.ServerSpecV2{Hostname: "node2"}, nil)
29612961
require.NoError(t, err)
29622962

29632963
t.Run("CreateNode", func(t *testing.T) {

‎lib/services/local/presence.go

+28
Original file line numberDiff line numberDiff line change
@@ -371,6 +371,34 @@ func (s *PresenceService) UpsertNode(ctx context.Context, server types.Server) (
371371
}, nil
372372
}
373373

374+
// UpdateNode conditionally updates the provided server.
375+
func (s *PresenceService) UpdateNode(ctx context.Context, server types.Server) (types.Server, error) {
376+
if server.GetNamespace() == "" {
377+
server.SetNamespace(apidefaults.Namespace)
378+
}
379+
380+
if n := server.GetNamespace(); n != apidefaults.Namespace {
381+
return nil, trace.BadParameter("cannot place node in namespace %q, custom namespaces are deprecated", n)
382+
}
383+
rev := server.GetRevision()
384+
value, err := services.MarshalServer(server)
385+
if err != nil {
386+
return nil, trace.Wrap(err)
387+
}
388+
lease, err := s.ConditionalUpdate(ctx, backend.Item{
389+
Key: backend.NewKey(nodesPrefix, server.GetNamespace(), server.GetName()),
390+
Value: value,
391+
Expires: server.Expiry(),
392+
Revision: rev,
393+
})
394+
if err != nil {
395+
return nil, trace.Wrap(err)
396+
}
397+
398+
server.SetRevision(lease.Revision)
399+
return server, nil
400+
}
401+
374402
// GetAuthServers returns a list of registered servers
375403
func (s *PresenceService) GetAuthServers() ([]types.Server, error) {
376404
return s.getServers(context.TODO(), types.KindAuthServer, authServersPrefix)

‎lib/services/local/presence_test.go

+25
Original file line numberDiff line numberDiff line change
@@ -262,6 +262,31 @@ func TestNodeCRUD(t *testing.T) {
262262
require.NoError(t, err)
263263
})
264264

265+
t.Run("UpdateNode", func(t *testing.T) {
266+
node1, err = presence.GetNode(ctx, apidefaults.Namespace, node1.GetName())
267+
require.NoError(t, err)
268+
node1.SetAddr("1.2.3.4:8080")
269+
270+
node2, err = presence.GetNode(ctx, apidefaults.Namespace, node2.GetName())
271+
require.NoError(t, err)
272+
273+
node1, err = presence.UpdateNode(ctx, node1)
274+
require.NoError(t, err)
275+
require.Equal(t, "1.2.3.4:8080", node1.GetAddr())
276+
277+
rev := node2.GetRevision()
278+
node2.SetAddr("1.2.3.4:9090")
279+
node2.SetRevision(node1.GetRevision())
280+
281+
_, err = presence.UpdateNode(ctx, node2)
282+
require.True(t, trace.IsCompareFailed(err))
283+
node2.SetRevision(rev)
284+
285+
node2, err = presence.UpdateNode(ctx, node2)
286+
require.NoError(t, err)
287+
require.Equal(t, "1.2.3.4:9090", node2.GetAddr())
288+
})
289+
265290
// Run NodeGetters in nested subtests to allow parallelization.
266291
t.Run("NodeGetters", func(t *testing.T) {
267292
t.Run("GetNodes", func(t *testing.T) {

‎lib/services/presence.go

+1
Original file line numberDiff line numberDiff line change
@@ -212,4 +212,5 @@ type PresenceInternal interface {
212212
UpsertHostUserInteractionTime(ctx context.Context, name string, loginTime time.Time) error
213213
GetHostUserInteractionTime(ctx context.Context, name string) (time.Time, error)
214214
UpsertReverseTunnelV2(ctx context.Context, tunnel types.ReverseTunnel) (types.ReverseTunnel, error)
215+
UpdateNode(ctx context.Context, server types.Server) (types.Server, error)
215216
}

‎lib/services/suite/suite.go

+3
Original file line numberDiff line numberDiff line change
@@ -488,6 +488,7 @@ func (s *ServicesTestSuite) ServerCRUD(t *testing.T) {
488488
require.Empty(t, out)
489489

490490
srv := NewServer(types.KindNode, "srv1", "127.0.0.1:2022", apidefaults.Namespace)
491+
srv.Spec.Hostname = "llama"
491492
_, err = s.PresenceS.UpsertNode(ctx, srv)
492493
require.NoError(t, err)
493494

@@ -513,6 +514,7 @@ func (s *ServicesTestSuite) ServerCRUD(t *testing.T) {
513514
require.Empty(t, out)
514515

515516
proxy := NewServer(types.KindProxy, "proxy1", "127.0.0.1:2023", apidefaults.Namespace)
517+
proxy.Spec.Hostname = "proxy.llama"
516518
require.NoError(t, s.PresenceS.UpsertProxy(ctx, proxy))
517519

518520
out, err = s.PresenceS.GetProxies()
@@ -533,6 +535,7 @@ func (s *ServicesTestSuite) ServerCRUD(t *testing.T) {
533535
require.Empty(t, out)
534536

535537
auth := NewServer(types.KindAuthServer, "auth1", "127.0.0.1:2025", apidefaults.Namespace)
538+
auth.Spec.Hostname = "auth.llama"
536539
require.NoError(t, s.PresenceS.UpsertAuthServer(ctx, auth))
537540

538541
out, err = s.PresenceS.GetAuthServers()

‎lib/web/apiserver_test.go

+3-2
Original file line numberDiff line numberDiff line change
@@ -1150,7 +1150,7 @@ func TestClusterNodesGet(t *testing.T) {
11501150
server1 := servers[0]
11511151

11521152
// Add another node.
1153-
server2, err := types.NewServerWithLabels("server2", types.KindNode, types.ServerSpecV2{}, map[string]string{"test-field": "test-value"})
1153+
server2, err := types.NewServerWithLabels("server2", types.KindNode, types.ServerSpecV2{Hostname: "server2"}, map[string]string{"test-field": "test-value"})
11541154
require.NoError(t, err)
11551155
_, err = env.server.Auth().UpsertNode(context.Background(), server2)
11561156
require.NoError(t, err)
@@ -1186,7 +1186,8 @@ func TestClusterNodesGet(t *testing.T) {
11861186
Kind: types.KindNode,
11871187
SubKind: types.SubKindTeleportNode,
11881188
ClusterName: clusterName,
1189-
Name: "server2",
1189+
Name: server2.GetName(),
1190+
Hostname: server2.GetHostname(),
11901191
Labels: []ui.Label{{Name: "test-field", Value: "test-value"}},
11911192
Tunnel: false,
11921193
SSHLogins: []string{pack.login},

0 commit comments

Comments
 (0)
Please sign in to comment.