From d802b56db0d8c0a6f5a59c585ab949f4b4f73cf4 Mon Sep 17 00:00:00 2001 From: Antoine Grondin Date: Tue, 15 Sep 2020 05:31:43 +0900 Subject: [PATCH] assert that GetCertificate gets called as expected --- v2/spiffetls/tlsconfig/config_test.go | 50 +++++++++++++++++++++------ 1 file changed, 39 insertions(+), 11 deletions(-) diff --git a/v2/spiffetls/tlsconfig/config_test.go b/v2/spiffetls/tlsconfig/config_test.go index 89d62ed5..978c3ae2 100644 --- a/v2/spiffetls/tlsconfig/config_test.go +++ b/v2/spiffetls/tlsconfig/config_test.go @@ -20,7 +20,7 @@ import ( "github.com/stretchr/testify/require" ) -var LocalTrace = tlsconfig.Trace{ +var localTrace = tlsconfig.Trace{ GetCertificate: func() interface{} { fmt.Printf("got start of GetTLSCertificate\n") return nil @@ -71,7 +71,7 @@ func TestMTLSClientConfig(t *testing.T) { svid := &x509svid.SVID{} config := tlsconfig.MTLSClientConfig(svid, bundle, tlsconfig.AuthorizeAny(), - tlsconfig.WithTrace(LocalTrace), + tlsconfig.WithTrace(localTrace), ) assert.Nil(t, config.Certificates) @@ -92,7 +92,7 @@ func TestHookMTLSClientConfig(t *testing.T) { config := createTestTLSConfig(base) tlsconfig.HookMTLSClientConfig(config, svid, bundle, tlsconfig.AuthorizeAny(), - tlsconfig.WithTrace(LocalTrace), + tlsconfig.WithTrace(localTrace), ) assert.Nil(t, config.Certificates) @@ -111,7 +111,7 @@ func TestMTLSWebClientConfig(t *testing.T) { roots := x509.NewCertPool() config := tlsconfig.MTLSWebClientConfig(svid, roots, - tlsconfig.WithTrace(LocalTrace), + tlsconfig.WithTrace(localTrace), ) assert.Nil(t, config.Certificates) @@ -131,7 +131,7 @@ func TestHookMTLSWebClientConfig(t *testing.T) { roots := x509.NewCertPool() tlsconfig.HookMTLSWebClientConfig(config, svid, roots, - tlsconfig.WithTrace(LocalTrace), + tlsconfig.WithTrace(localTrace), ) // Expected AuthFields @@ -150,7 +150,7 @@ func TestTLSServerConfig(t *testing.T) { svid := &x509svid.SVID{} config := tlsconfig.TLSServerConfig(svid, - tlsconfig.WithTrace(LocalTrace), + tlsconfig.WithTrace(localTrace), ) assert.Nil(t, config.Certificates) @@ -169,7 +169,7 @@ func TestHookTLSServerConfig(t *testing.T) { config := createTestTLSConfig(base) tlsconfig.HookTLSServerConfig(config, svid, - tlsconfig.WithTrace(LocalTrace), + tlsconfig.WithTrace(localTrace), ) assert.Nil(t, config.Certificates) @@ -189,7 +189,7 @@ func TestMTLSServerConfig(t *testing.T) { svid := &x509svid.SVID{} config := tlsconfig.MTLSServerConfig(svid, bundle, tlsconfig.AuthorizeAny(), - tlsconfig.WithTrace(LocalTrace), + tlsconfig.WithTrace(localTrace), ) assert.Nil(t, config.Certificates) @@ -210,7 +210,7 @@ func TestHookMTLSServerConfig(t *testing.T) { config := createTestTLSConfig(base) tlsconfig.HookMTLSServerConfig(config, svid, bundle, tlsconfig.AuthorizeAny(), - tlsconfig.WithTrace(LocalTrace), + tlsconfig.WithTrace(localTrace), ) assert.Nil(t, config.Certificates) @@ -261,6 +261,22 @@ func TestHookMTLSWebServerConfig(t *testing.T) { assertUnrelatedFieldsUntouched(t, base, config) } +func hookedTracer(onGetCertificate, onGotCertificate func()) tlsconfig.Trace { + return tlsconfig.Trace{ + GetCertificate: func() interface{} { + if onGetCertificate != nil { + onGetCertificate() + } + return nil + }, + GotCertificate: func(interface{}, tlsconfig.GotCertificateInfo) { + if onGotCertificate != nil { + onGotCertificate() + } + }, + } +} + func TestGetCertificate(t *testing.T) { testCases := []struct { name string @@ -293,7 +309,12 @@ func TestGetCertificate(t *testing.T) { for _, testCase := range testCases { testCase := testCase t.Run(testCase.name, func(t *testing.T) { - getCertificate := tlsconfig.GetCertificate(testCase.source, tlsconfig.WithTrace(LocalTrace)) + getCertificateCalls := 0 + tracer := hookedTracer( + func() { getCertificateCalls++ }, + nil, + ) + getCertificate := tlsconfig.GetCertificate(testCase.source, tlsconfig.WithTrace(tracer)) require.NotNil(t, getCertificate) tlsCert, err := getCertificate(&tls.ClientHelloInfo{}) @@ -305,6 +326,7 @@ func TestGetCertificate(t *testing.T) { require.NoError(t, err) require.Equal(t, testCase.expectedCerts, tlsCert.Certificate) + require.Equal(t, 1, getCertificateCalls) }) } } @@ -341,7 +363,12 @@ func TestGetClientCertificate(t *testing.T) { for _, testCase := range testCases { testCase := testCase t.Run(testCase.name, func(t *testing.T) { - getClientCertificate := tlsconfig.GetClientCertificate(testCase.source, tlsconfig.WithTrace(LocalTrace)) + getCertificateCalls := 0 + tracer := hookedTracer( + func() { getCertificateCalls++ }, + nil, + ) + getClientCertificate := tlsconfig.GetClientCertificate(testCase.source, tlsconfig.WithTrace(tracer)) require.NotNil(t, getClientCertificate) tlsCert, err := getClientCertificate(&tls.CertificateRequestInfo{}) @@ -353,6 +380,7 @@ func TestGetClientCertificate(t *testing.T) { require.NoError(t, err) require.Equal(t, testCase.expectedCerts, tlsCert.Certificate) + require.Equal(t, 1, getCertificateCalls) }) } }