19
19
package auth
20
20
21
21
import (
22
+ "cmp"
22
23
"context"
23
24
"crypto/rand"
24
25
"crypto/x509"
@@ -34,7 +35,7 @@ import (
34
35
"testing"
35
36
"time"
36
37
37
- "github.com/google/go-cmp/cmp"
38
+ gocmp "github.com/google/go-cmp/cmp"
38
39
"github.com/google/go-cmp/cmp/cmpopts"
39
40
"github.com/google/uuid"
40
41
"github.com/gravitational/license"
@@ -307,7 +308,7 @@ func TestSessions(t *testing.T) {
307
308
require .NoError (t , err )
308
309
assert .Empty (t , out .GetSSHPriv ())
309
310
assert .Empty (t , out .GetTLSPriv ())
310
- assert .Empty (t , cmp .Diff (ws , out ,
311
+ assert .Empty (t , gocmp .Diff (ws , out ,
311
312
cmpopts .IgnoreFields (types.Metadata {}, "Revision" ),
312
313
cmpopts .IgnoreFields (types.WebSessionSpecV2 {}, "Priv" , "TLSPriv" )))
313
314
@@ -1655,7 +1656,7 @@ func TestServer_AugmentContextUserCertificates(t *testing.T) {
1655
1656
AssetTag : test .opts .DeviceExtensions .AssetTag ,
1656
1657
CredentialId : test .opts .DeviceExtensions .CredentialID ,
1657
1658
}
1658
- if diff := cmp .Diff (want , got ); diff != "" {
1659
+ if diff := gocmp .Diff (want , got ); diff != "" {
1659
1660
t .Errorf ("certEvent.Identity.DeviceExtensions mismatch (-want +got)\n %s" , diff )
1660
1661
}
1661
1662
}
@@ -2301,12 +2302,12 @@ func TestServer_ExtendWebSession_deviceExtensions(t *testing.T) {
2301
2302
// Assert TLS extensions.
2302
2303
_ , newIdentity := parseX509PEMAndIdentity (t , newSession .GetTLSCert ())
2303
2304
wantExts := tlsca .DeviceExtensions (* deviceExts )
2304
- if diff := cmp .Diff (wantExts , newIdentity .DeviceExtensions ); diff != "" {
2305
+ if diff := gocmp .Diff (wantExts , newIdentity .DeviceExtensions ); diff != "" {
2305
2306
t .Errorf ("newSession.TLSCert DeviceExtensions mismatch (-want +got)\n %s" , diff )
2306
2307
}
2307
2308
2308
2309
// Assert SSH extensions.
2309
- if diff := cmp .Diff (wantExts , parseSSHDeviceExtensions (t , newSession .GetPub ())); diff != "" {
2310
+ if diff := gocmp .Diff (wantExts , parseSSHDeviceExtensions (t , newSession .GetPub ())); diff != "" {
2310
2311
t .Errorf ("newSession.Pub DeviceExtensions mismatch (-want +got)\n %s" , diff )
2311
2312
}
2312
2313
})
@@ -2545,7 +2546,7 @@ func TestGenerateUserCertWithCertExtension(t *testing.T) {
2545
2546
// Validate audit event.
2546
2547
lastEvent := p .mockEmitter .LastEvent ()
2547
2548
require .IsType (t , & apievents.CertificateCreate {}, lastEvent )
2548
- require .Empty (t , cmp .Diff (
2549
+ require .Empty (t , gocmp .Diff (
2549
2550
& apievents.CertificateCreate {
2550
2551
Metadata : apievents.Metadata {
2551
2552
Type : events .CertificateCreateEvent ,
@@ -3801,15 +3802,15 @@ func compareDevices(t *testing.T, ignoreUpdateAndCounter bool, got []*types.MFAD
3801
3802
}
3802
3803
3803
3804
// Ignore LastUsed and SignatureCounter?
3804
- var opts []cmp .Option
3805
+ var opts []gocmp .Option
3805
3806
if ignoreUpdateAndCounter {
3806
- opts = append (opts , cmp .FilterPath (func (path cmp .Path ) bool {
3807
+ opts = append (opts , gocmp .FilterPath (func (path gocmp .Path ) bool {
3807
3808
p := path .String ()
3808
3809
return p == "LastUsed" || p == "Device.Webauthn.SignatureCounter"
3809
- }, cmp .Ignore ()))
3810
+ }, gocmp .Ignore ()))
3810
3811
}
3811
3812
3812
- if diff := cmp .Diff (want , got , opts ... ); diff != "" {
3813
+ if diff := gocmp .Diff (want , got , opts ... ); diff != "" {
3813
3814
t .Errorf ("compareDevices mismatch (-want +got):\n %s" , diff )
3814
3815
}
3815
3816
}
@@ -4444,3 +4445,120 @@ func newGlobalNotificationWithExpiry(t *testing.T, title string, expires *timest
4444
4445
4445
4446
return & notification
4446
4447
}
4448
+
4449
+ // TestServerHostnameSanitization tests that persisting servers with
4450
+ // "invalid" hostnames results in the hostname being sanitized and the
4451
+ // illegal name being placed in a label.
4452
+ func TestServerHostnameSanitization (t * testing.T ) {
4453
+ t .Parallel ()
4454
+ ctx := context .Background ()
4455
+ srv , err := NewTestAuthServer (TestAuthServerConfig {Dir : t .TempDir ()})
4456
+ require .NoError (t , err )
4457
+
4458
+ cases := []struct {
4459
+ name string
4460
+ hostname string
4461
+ addr string
4462
+ invalidHostname bool
4463
+ invalidAddr bool
4464
+ }{
4465
+ {
4466
+ name : "valid dns hostname" ,
4467
+ hostname : "llama.example.com" ,
4468
+ },
4469
+ {
4470
+ name : "valid friendly hostname" ,
4471
+ hostname : "llama" ,
4472
+ },
4473
+ {
4474
+ name : "uuid hostname" ,
4475
+ hostname : uuid .NewString (),
4476
+ },
4477
+ {
4478
+ name : "uuid dns hostname" ,
4479
+ hostname : uuid .NewString () + ".example.com" ,
4480
+ },
4481
+ {
4482
+ name : "empty hostname" ,
4483
+ hostname : "" ,
4484
+ invalidHostname : true ,
4485
+ },
4486
+ {
4487
+ name : "exceptionally long hostname" ,
4488
+ hostname : strings .Repeat ("a" , serverHostnameMaxLen * 2 ),
4489
+ invalidHostname : true ,
4490
+ },
4491
+ {
4492
+ name : "invalid dns hostname" ,
4493
+ hostname : "llama..example.com" ,
4494
+ invalidHostname : true ,
4495
+ },
4496
+ {
4497
+ name : "spaces in hostname" ,
4498
+ hostname : "the quick brown fox jumps over the lazy dog" ,
4499
+ invalidHostname : true ,
4500
+ },
4501
+ {
4502
+ name : "invalid addr" ,
4503
+ hostname : ".." ,
4504
+ addr : "..:2345" ,
4505
+ invalidHostname : true ,
4506
+ invalidAddr : true ,
4507
+ },
4508
+ }
4509
+
4510
+ for _ , test := range cases {
4511
+ t .Run (test .name , func (t * testing.T ) {
4512
+ for _ , subKind := range []string {types .KindNode , types .SubKindOpenSSHNode } {
4513
+ t .Run (subKind , func (t * testing.T ) {
4514
+ server := & types.ServerV2 {
4515
+ Kind : types .KindNode ,
4516
+ SubKind : subKind ,
4517
+ Metadata : types.Metadata {
4518
+ Name : uuid .NewString (),
4519
+ },
4520
+ Spec : types.ServerSpecV2 {
4521
+ Hostname : test .hostname ,
4522
+ Addr : cmp .Or (test .addr , "abcd:1234" ),
4523
+ },
4524
+ }
4525
+ if subKind == types .KindNode {
4526
+ server .SubKind = ""
4527
+ }
4528
+
4529
+ _ , err = srv .AuthServer .UpsertNode (ctx , server )
4530
+ require .NoError (t , err )
4531
+
4532
+ replacedValue , _ := server .GetLabel ("teleport.internal/invalid-hostname" )
4533
+ if ! test .invalidHostname {
4534
+ assert .Equal (t , test .hostname , server .GetHostname ())
4535
+ assert .Empty (t , replacedValue )
4536
+ return
4537
+ }
4538
+
4539
+ assert .Equal (t , test .hostname , replacedValue )
4540
+ switch subKind {
4541
+ case types .SubKindOpenSSHNode :
4542
+ host , _ , err := net .SplitHostPort (server .GetAddr ())
4543
+ assert .NoError (t , err )
4544
+ if ! test .invalidAddr {
4545
+ // If the address is valid, then the hostname should be set
4546
+ // to the host of the addr field.
4547
+ assert .Equal (t , host , server .GetHostname ())
4548
+ } else {
4549
+ // If the address is not valid, then the hostname should be
4550
+ // set to a UUID.
4551
+ assert .NotEqual (t , host , server .GetHostname ())
4552
+ assert .NotEqual (t , server .GetName (), server .GetHostname ())
4553
+
4554
+ _ , err := uuid .Parse (server .GetHostname ())
4555
+ require .NoError (t , err )
4556
+ }
4557
+ default :
4558
+ assert .Equal (t , server .GetName (), server .GetHostname ())
4559
+ }
4560
+ })
4561
+ }
4562
+ })
4563
+ }
4564
+ }
0 commit comments