From 14bc5c4533f2fac0133a6d81e80ac647cee29adf Mon Sep 17 00:00:00 2001 From: Bin Xia Date: Thu, 10 Mar 2022 09:16:26 +0800 Subject: [PATCH] Support proxy in case kms-plugin can not access key vault (#119) --- cmd/server/main.go | 6 ++- pkg/auth/auth.go | 40 ++++++++++++++++--- pkg/auth/auth_test.go | 22 +++++++---- pkg/consts/consts.go | 16 ++++++++ pkg/plugin/keyvault.go | 24 ++++++++++-- pkg/plugin/keyvault_test.go | 78 +++++++++++++++++++++++++++---------- pkg/plugin/server.go | 4 +- 7 files changed, 151 insertions(+), 39 deletions(-) create mode 100644 pkg/consts/consts.go diff --git a/cmd/server/main.go b/cmd/server/main.go index cab46337..ced7b32d 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -43,6 +43,10 @@ var ( healthzTimeout = flag.Duration("healthz-timeout", 20*time.Second, "RPC timeout for health check") metricsBackend = flag.String("metrics-backend", "prometheus", "Backend used for metrics") metricsAddress = flag.String("metrics-addr", "8095", "The address the metric endpoint binds to") + + proxyMode = flag.Bool("proxy-mode", false, "Proxy mode") + proxyAddress = flag.String("proxy-address", "", "proxy address") + proxyPort = flag.Int("proxy-port", 7788, "port for proxy") ) func main() { @@ -68,7 +72,7 @@ func main() { } klog.InfoS("Starting KeyManagementServiceServer service", "version", version.BuildVersion, "buildDate", version.BuildDate) - kmsServer, err := plugin.New(ctx, *configFilePath, *keyvaultName, *keyName, *keyVersion) + kmsServer, err := plugin.New(ctx, *configFilePath, *keyvaultName, *keyName, *keyVersion, *proxyMode, *proxyAddress, *proxyPort) if err != nil { klog.Fatalf("failed to create server, error: %v", err) } diff --git a/pkg/auth/auth.go b/pkg/auth/auth.go index e5c2eff4..b26e5312 100644 --- a/pkg/auth/auth.go +++ b/pkg/auth/auth.go @@ -9,11 +9,13 @@ import ( "crypto/rsa" "crypto/x509" "fmt" + "net/http" "os" "regexp" "strings" "github.com/Azure/kubernetes-kms/pkg/config" + "github.com/Azure/kubernetes-kms/pkg/consts" "github.com/Azure/go-autorest/autorest" "github.com/Azure/go-autorest/autorest/adal" @@ -23,9 +25,9 @@ import ( ) // GetKeyvaultToken() returns token for Keyvault endpoint -func GetKeyvaultToken(config *config.AzureConfig, env *azure.Environment) (authorizer autorest.Authorizer, err error) { +func GetKeyvaultToken(config *config.AzureConfig, env *azure.Environment, proxyMode bool) (authorizer autorest.Authorizer, err error) { kvEndPoint := strings.TrimSuffix(env.KeyVaultEndpoint, "/") - servicePrincipalToken, err := GetServicePrincipalToken(config, env.ActiveDirectoryEndpoint, kvEndPoint) + servicePrincipalToken, err := GetServicePrincipalToken(config, env.ActiveDirectoryEndpoint, kvEndPoint, proxyMode) if err != nil { return nil, err } @@ -34,7 +36,7 @@ func GetKeyvaultToken(config *config.AzureConfig, env *azure.Environment) (autho } // GetServicePrincipalToken creates a new service principal token based on the configuration -func GetServicePrincipalToken(config *config.AzureConfig, aadEndpoint, resource string) (adal.OAuthTokenProvider, error) { +func GetServicePrincipalToken(config *config.AzureConfig, aadEndpoint, resource string, proxyMode bool) (adal.OAuthTokenProvider, error) { oauthConfig, err := adal.NewOAuthConfig(aadEndpoint, config.TenantID) if err != nil { return nil, fmt.Errorf("failed to create OAuth config, error: %v", err) @@ -64,11 +66,18 @@ func GetServicePrincipalToken(config *config.AzureConfig, aadEndpoint, resource klog.V(2).InfoS("azure: using client_id+client_secret to retrieve access token", "clientID", redactClientCredentials(config.ClientID), "clientSecret", redactClientCredentials(config.ClientSecret)) - return adal.NewServicePrincipalToken( + spt, err := adal.NewServicePrincipalToken( *oauthConfig, config.ClientID, config.ClientSecret, resource) + if err != nil { + return nil, err + } + if proxyMode { + return addTargetTypeHeader(spt), nil + } + return spt, nil } if len(config.AADClientCertPath) > 0 && len(config.AADClientCertPassword) > 0 { @@ -81,12 +90,19 @@ func GetServicePrincipalToken(config *config.AzureConfig, aadEndpoint, resource if err != nil { return nil, fmt.Errorf("failed to decode the client certificate, error: %v", err) } - return adal.NewServicePrincipalTokenFromCertificate( + spt, err := adal.NewServicePrincipalTokenFromCertificate( *oauthConfig, config.ClientID, certificate, privateKey, resource) + if err != nil { + return nil, err + } + if proxyMode { + return addTargetTypeHeader(spt), nil + } + return spt, nil } return nil, fmt.Errorf("no credentials provided for accessing keyvault") @@ -124,3 +140,17 @@ func redactClientCredentials(sensitiveString string) string { r, _ := regexp.Compile(`^(\S{4})(\S|\s)*(\S{4})$`) return r.ReplaceAllString(sensitiveString, "$1##### REDACTED #####$3") } + +// addTargetTypeHeader adds the target header if proxy mode is enabled +func addTargetTypeHeader(spt *adal.ServicePrincipalToken) *adal.ServicePrincipalToken { + spt.SetSender(autorest.CreateSender( + (func() autorest.SendDecorator { + return func(s autorest.Sender) autorest.Sender { + return autorest.SenderFunc(func(r *http.Request) (*http.Response, error) { + r.Header.Set(consts.RequestHeaderTargetType, consts.TargetTypeAzureActiveDirectory) + return s.Do(r) + }) + } + })())) + return spt +} diff --git a/pkg/auth/auth_test.go b/pkg/auth/auth_test.go index f14defd6..c2d56f47 100644 --- a/pkg/auth/auth_test.go +++ b/pkg/auth/auth_test.go @@ -61,8 +61,9 @@ func TestRedactClientCredentials(t *testing.T) { func TestGetServicePrincipalTokenFromMSIWithUserAssignedID(t *testing.T) { tests := []struct { - name string - config *config.AzureConfig + name string + config *config.AzureConfig + proxyMode bool // The proxy mode doesn't matter if user-assigned managed identity is used to get service principal token }{ { name: "using user-assigned managed identity to access keyvault", @@ -73,6 +74,7 @@ func TestGetServicePrincipalTokenFromMSIWithUserAssignedID(t *testing.T) { ClientID: "AADClientID", ClientSecret: "AADClientSecret", }, + proxyMode: false, }, // The Azure service principal is ignored when // UseManagedIdentityExtension is set to true @@ -82,12 +84,13 @@ func TestGetServicePrincipalTokenFromMSIWithUserAssignedID(t *testing.T) { UseManagedIdentityExtension: true, UserAssignedIdentityID: "clientID", }, + proxyMode: true, }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { - token, err := GetServicePrincipalToken(test.config, "https://login.microsoftonline.com/", "https://vault.azure.net") + token, err := GetServicePrincipalToken(test.config, "https://login.microsoftonline.com/", "https://vault.azure.net", test.proxyMode) if err != nil { t.Fatalf("expected err to be nil, got: %v", err) } @@ -108,14 +111,16 @@ func TestGetServicePrincipalTokenFromMSIWithUserAssignedID(t *testing.T) { func TestGetServicePrincipalTokenFromMSI(t *testing.T) { tests := []struct { - name string - config *config.AzureConfig + name string + config *config.AzureConfig + proxyMode bool // The proxy mode doesn't matter if MSI is used to get service principal token }{ { name: "using system-assigned managed identity to access keyvault", config: &config.AzureConfig{ UseManagedIdentityExtension: true, }, + proxyMode: false, }, // The Azure service principal is ignored when // UseManagedIdentityExtension is set to true @@ -127,12 +132,13 @@ func TestGetServicePrincipalTokenFromMSI(t *testing.T) { ClientID: "AADClientID", ClientSecret: "AADClientSecret", }, + proxyMode: true, }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { - token, err := GetServicePrincipalToken(test.config, "https://login.microsoftonline.com/", "https://vault.azure.net") + token, err := GetServicePrincipalToken(test.config, "https://login.microsoftonline.com/", "https://vault.azure.net", test.proxyMode) if err != nil { t.Fatalf("expected err to be nil, got: %v", err) } @@ -168,7 +174,7 @@ func TestGetServicePrincipalToken(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - token, err := GetServicePrincipalToken(test.config, "https://login.microsoftonline.com/", "https://vault.azure.net") + token, err := GetServicePrincipalToken(test.config, "https://login.microsoftonline.com/", "https://vault.azure.net", false) if err != nil { t.Fatalf("expected err to be nil, got: %v", err) } @@ -183,7 +189,7 @@ func TestGetServicePrincipalToken(t *testing.T) { t.Fatalf("expected err to be nil, got: %v", err) } if !reflect.DeepEqual(token, spt) { - t.Fatalf("expected: %v, got: %v", spt, token) + t.Fatalf("expected: %+v, got: %+v", spt, token) } }) } diff --git a/pkg/consts/consts.go b/pkg/consts/consts.go new file mode 100644 index 00000000..2f12d730 --- /dev/null +++ b/pkg/consts/consts.go @@ -0,0 +1,16 @@ +// Copyright (c) Microsoft and contributors. All rights reserved. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +package consts + +const ( + // In proxy mode, the header is added into the requests from kms-plugin. + // The proxy will check the header and forward the request to different destinations. + // e.g. When the value of the header "x-azure-proxy-target" is "KeyVault", the request + // is forwared to Azure Key Vault by the proxy. + RequestHeaderTargetType = "x-azure-proxy-target" + TargetTypeAzureActiveDirectory = "AzureActiveDirectory" + TargetTypeKeyVault = "KeyVault" +) diff --git a/pkg/plugin/keyvault.go b/pkg/plugin/keyvault.go index 0f565ccb..0573cc46 100644 --- a/pkg/plugin/keyvault.go +++ b/pkg/plugin/keyvault.go @@ -10,13 +10,16 @@ import ( "encoding/base64" "fmt" "regexp" + "strings" "github.com/Azure/kubernetes-kms/pkg/auth" "github.com/Azure/kubernetes-kms/pkg/config" + "github.com/Azure/kubernetes-kms/pkg/consts" "github.com/Azure/kubernetes-kms/pkg/utils" "github.com/Azure/kubernetes-kms/pkg/version" kv "github.com/Azure/azure-sdk-for-go/services/keyvault/2016-10-01/keyvault" + "github.com/Azure/go-autorest/autorest" "github.com/Azure/go-autorest/autorest/azure" "k8s.io/klog/v2" ) @@ -38,7 +41,7 @@ type keyVaultClient struct { } // NewKeyVaultClient returns a new key vault client to use for kms operations -func newKeyVaultClient(config *config.AzureConfig, vaultName, keyName, keyVersion string) (*keyVaultClient, error) { +func newKeyVaultClient(config *config.AzureConfig, vaultName, keyName, keyVersion string, proxyMode bool, proxyAddress string, proxyPort int) (*keyVaultClient, error) { // Sanitize vaultName, keyName, keyVersion. (https://github.com/Azure/kubernetes-kms/issues/85) vaultName = utils.SanitizeString(vaultName) keyName = utils.SanitizeString(keyName) @@ -58,7 +61,11 @@ func newKeyVaultClient(config *config.AzureConfig, vaultName, keyName, keyVersio if err != nil { return nil, fmt.Errorf("failed to parse cloud environment: %s, error: %+v", config.Cloud, err) } - token, err := auth.GetKeyvaultToken(config, env) + if proxyMode { + env.ActiveDirectoryEndpoint = fmt.Sprintf("http://%s:%d/", proxyAddress, proxyPort) + } + + token, err := auth.GetKeyvaultToken(config, env, proxyMode) if err != nil { return nil, fmt.Errorf("failed to get key vault token, error: %+v", err) } @@ -69,7 +76,12 @@ func newKeyVaultClient(config *config.AzureConfig, vaultName, keyName, keyVersio return nil, fmt.Errorf("failed to get vault url, error: %+v", err) } - klog.InfoS("using kms key for encrypt/decrypt", "vaultName", vaultName, "keyName", keyName, "keyVersion", keyVersion) + if proxyMode { + kvClient.RequestInspector = autorest.WithHeader(consts.RequestHeaderTargetType, consts.TargetTypeKeyVault) + vaultURL = getProxiedVaultURL(vaultURL, proxyAddress, proxyPort) + } + + klog.InfoS("using kms key for encrypt/decrypt", "vaultURL", *vaultURL, "keyName", keyName, "keyVersion", keyVersion) client := &keyVaultClient{ baseClient: kvClient, @@ -130,5 +142,11 @@ func getVaultURL(vaultName string, azureEnvironment *azure.Environment) (vaultUR vaultDNSSuffixValue := azureEnvironment.KeyVaultDNSSuffix vaultURI := "https://" + vaultName + "." + vaultDNSSuffixValue + "/" + return &vaultURI, nil } + +func getProxiedVaultURL(vaultURL *string, proxyAddress string, proxyPort int) *string { + proxiedVaultURL := fmt.Sprintf("http://%s:%d/%s", proxyAddress, proxyPort, strings.TrimPrefix(*vaultURL, "https://")) + return &proxiedVaultURL +} diff --git a/pkg/plugin/keyvault_test.go b/pkg/plugin/keyvault_test.go index 3e41f031..6b12ee93 100644 --- a/pkg/plugin/keyvault_test.go +++ b/pkg/plugin/keyvault_test.go @@ -9,29 +9,39 @@ import ( "strings" "testing" + "github.com/Azure/go-autorest/autorest/azure" + "github.com/Azure/kubernetes-kms/pkg/auth" "github.com/Azure/kubernetes-kms/pkg/config" + "github.com/Azure/kubernetes-kms/pkg/utils" ) func TestNewKeyVaultClient(t *testing.T) { + // nolint: maligned tests := []struct { - desc string - config *config.AzureConfig - vaultName string - keyName string - keyVersion string - vaultSKU string - expectedErr bool + desc string + config *config.AzureConfig + vaultName string + keyName string + keyVersion string + proxyMode bool + proxyAddress string + proxyPort int + vaultSKU string + expectedErr bool + expectedVaultURL string }{ { desc: "vault name not provided", config: &config.AzureConfig{}, + proxyMode: false, expectedErr: true, }, { desc: "key name not provided", config: &config.AzureConfig{}, vaultName: "testkv", + proxyMode: false, expectedErr: true, }, { @@ -39,6 +49,7 @@ func TestNewKeyVaultClient(t *testing.T) { config: &config.AzureConfig{}, vaultName: "testkv", keyName: "k8s", + proxyMode: false, expectedErr: true, }, { @@ -50,26 +61,42 @@ func TestNewKeyVaultClient(t *testing.T) { expectedErr: true, }, { - desc: "no error", - config: &config.AzureConfig{ClientID: "clientid", ClientSecret: "clientsecret"}, - vaultName: "testkv", - keyName: "key1", - keyVersion: "262067a9e8ba401aa8a746c5f1a7e147", - expectedErr: false, + desc: "no error", + config: &config.AzureConfig{ClientID: "clientid", ClientSecret: "clientsecret"}, + vaultName: "testkv", + keyName: "key1", + keyVersion: "262067a9e8ba401aa8a746c5f1a7e147", + proxyMode: false, + expectedErr: false, + expectedVaultURL: "https://testkv.vault.azure.net/", }, { - desc: "no error with double quotes", - config: &config.AzureConfig{ClientID: "clientid", ClientSecret: "clientsecret"}, - vaultName: "\"testkv\"", - keyName: "\"key1\"", - keyVersion: "\"262067a9e8ba401aa8a746c5f1a7e147\"", - expectedErr: false, + desc: "no error with double quotes", + config: &config.AzureConfig{ClientID: "clientid", ClientSecret: "clientsecret"}, + vaultName: "\"testkv\"", + keyName: "\"key1\"", + keyVersion: "\"262067a9e8ba401aa8a746c5f1a7e147\"", + proxyMode: false, + expectedErr: false, + expectedVaultURL: "https://testkv.vault.azure.net/", + }, + { + desc: "no error with proxy mode", + config: &config.AzureConfig{ClientID: "clientid", ClientSecret: "clientsecret"}, + vaultName: "testkv", + keyName: "key1", + keyVersion: "262067a9e8ba401aa8a746c5f1a7e147", + proxyMode: true, + proxyAddress: "localhost", + proxyPort: 7788, + expectedErr: false, + expectedVaultURL: "http://localhost:7788/testkv.vault.azure.net/", }, } for _, test := range tests { t.Run(test.desc, func(t *testing.T) { - kvClient, err := newKeyVaultClient(test.config, test.vaultName, test.keyName, test.keyVersion) + kvClient, err := newKeyVaultClient(test.config, test.vaultName, test.keyName, test.keyVersion, test.proxyMode, test.proxyAddress, test.proxyPort) if test.expectedErr && err == nil || !test.expectedErr && err != nil { t.Fatalf("expected error: %v, got error: %v", test.expectedErr, err) } @@ -80,6 +107,17 @@ func TestNewKeyVaultClient(t *testing.T) { if !strings.Contains(kvClient.baseClient.UserAgent, "k8s-kms-keyvault") { t.Fatalf("expected k8s-kms-keyvault user agent") } + + vaultURL, err := getVaultURL(utils.SanitizeString(test.vaultName), &azure.PublicCloud) + if err != nil { + t.Fatalf("expected no error of getting vault URL, got error: %v", err) + } + if test.proxyMode { + vaultURL = getProxiedVaultURL(vaultURL, test.proxyAddress, test.proxyPort) + } + if *vaultURL != test.expectedVaultURL { + t.Fatalf("expected vault URL: %v, got vault URL: %v", test.expectedVaultURL, *vaultURL) + } } }) } diff --git a/pkg/plugin/server.go b/pkg/plugin/server.go index 7287b1bd..b7957d4c 100644 --- a/pkg/plugin/server.go +++ b/pkg/plugin/server.go @@ -24,12 +24,12 @@ type KeyManagementServiceServer struct { } // New creates an instance of the KMS Service Server. -func New(ctx context.Context, configFilePath, vaultName, keyName, keyVersion string) (*KeyManagementServiceServer, error) { +func New(ctx context.Context, configFilePath, vaultName, keyName, keyVersion string, proxyMode bool, proxyAddress string, proxyPort int) (*KeyManagementServiceServer, error) { cfg, err := config.GetAzureConfig(configFilePath) if err != nil { return nil, err } - kvClient, err := newKeyVaultClient(cfg, vaultName, keyName, keyVersion) + kvClient, err := newKeyVaultClient(cfg, vaultName, keyName, keyVersion, proxyMode, proxyAddress, proxyPort) if err != nil { return nil, err }