Skip to content

Commit

Permalink
Support proxy in case kms-plugin can not access key vault (#119)
Browse files Browse the repository at this point in the history
  • Loading branch information
bingosummer authored Mar 10, 2022
1 parent 14cd0a7 commit 14bc5c4
Show file tree
Hide file tree
Showing 7 changed files with 151 additions and 39 deletions.
6 changes: 5 additions & 1 deletion cmd/server/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,10 @@ var (
healthzTimeout = flag.Duration("healthz-timeout", 20*time.Second, "RPC timeout for health check")
metricsBackend = flag.String("metrics-backend", "prometheus", "Backend used for metrics")
metricsAddress = flag.String("metrics-addr", "8095", "The address the metric endpoint binds to")

proxyMode = flag.Bool("proxy-mode", false, "Proxy mode")
proxyAddress = flag.String("proxy-address", "", "proxy address")
proxyPort = flag.Int("proxy-port", 7788, "port for proxy")
)

func main() {
Expand All @@ -68,7 +72,7 @@ func main() {
}

klog.InfoS("Starting KeyManagementServiceServer service", "version", version.BuildVersion, "buildDate", version.BuildDate)
kmsServer, err := plugin.New(ctx, *configFilePath, *keyvaultName, *keyName, *keyVersion)
kmsServer, err := plugin.New(ctx, *configFilePath, *keyvaultName, *keyName, *keyVersion, *proxyMode, *proxyAddress, *proxyPort)
if err != nil {
klog.Fatalf("failed to create server, error: %v", err)
}
Expand Down
40 changes: 35 additions & 5 deletions pkg/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,13 @@ import (
"crypto/rsa"
"crypto/x509"
"fmt"
"net/http"
"os"
"regexp"
"strings"

"github.com/Azure/kubernetes-kms/pkg/config"
"github.com/Azure/kubernetes-kms/pkg/consts"

"github.com/Azure/go-autorest/autorest"
"github.com/Azure/go-autorest/autorest/adal"
Expand All @@ -23,9 +25,9 @@ import (
)

// GetKeyvaultToken() returns token for Keyvault endpoint
func GetKeyvaultToken(config *config.AzureConfig, env *azure.Environment) (authorizer autorest.Authorizer, err error) {
func GetKeyvaultToken(config *config.AzureConfig, env *azure.Environment, proxyMode bool) (authorizer autorest.Authorizer, err error) {
kvEndPoint := strings.TrimSuffix(env.KeyVaultEndpoint, "/")
servicePrincipalToken, err := GetServicePrincipalToken(config, env.ActiveDirectoryEndpoint, kvEndPoint)
servicePrincipalToken, err := GetServicePrincipalToken(config, env.ActiveDirectoryEndpoint, kvEndPoint, proxyMode)
if err != nil {
return nil, err
}
Expand All @@ -34,7 +36,7 @@ func GetKeyvaultToken(config *config.AzureConfig, env *azure.Environment) (autho
}

// GetServicePrincipalToken creates a new service principal token based on the configuration
func GetServicePrincipalToken(config *config.AzureConfig, aadEndpoint, resource string) (adal.OAuthTokenProvider, error) {
func GetServicePrincipalToken(config *config.AzureConfig, aadEndpoint, resource string, proxyMode bool) (adal.OAuthTokenProvider, error) {
oauthConfig, err := adal.NewOAuthConfig(aadEndpoint, config.TenantID)
if err != nil {
return nil, fmt.Errorf("failed to create OAuth config, error: %v", err)
Expand Down Expand Up @@ -64,11 +66,18 @@ func GetServicePrincipalToken(config *config.AzureConfig, aadEndpoint, resource
klog.V(2).InfoS("azure: using client_id+client_secret to retrieve access token",
"clientID", redactClientCredentials(config.ClientID), "clientSecret", redactClientCredentials(config.ClientSecret))

return adal.NewServicePrincipalToken(
spt, err := adal.NewServicePrincipalToken(
*oauthConfig,
config.ClientID,
config.ClientSecret,
resource)
if err != nil {
return nil, err
}
if proxyMode {
return addTargetTypeHeader(spt), nil
}
return spt, nil
}

if len(config.AADClientCertPath) > 0 && len(config.AADClientCertPassword) > 0 {
Expand All @@ -81,12 +90,19 @@ func GetServicePrincipalToken(config *config.AzureConfig, aadEndpoint, resource
if err != nil {
return nil, fmt.Errorf("failed to decode the client certificate, error: %v", err)
}
return adal.NewServicePrincipalTokenFromCertificate(
spt, err := adal.NewServicePrincipalTokenFromCertificate(
*oauthConfig,
config.ClientID,
certificate,
privateKey,
resource)
if err != nil {
return nil, err
}
if proxyMode {
return addTargetTypeHeader(spt), nil
}
return spt, nil
}

return nil, fmt.Errorf("no credentials provided for accessing keyvault")
Expand Down Expand Up @@ -124,3 +140,17 @@ func redactClientCredentials(sensitiveString string) string {
r, _ := regexp.Compile(`^(\S{4})(\S|\s)*(\S{4})$`)
return r.ReplaceAllString(sensitiveString, "$1##### REDACTED #####$3")
}

// addTargetTypeHeader adds the target header if proxy mode is enabled
func addTargetTypeHeader(spt *adal.ServicePrincipalToken) *adal.ServicePrincipalToken {
spt.SetSender(autorest.CreateSender(
(func() autorest.SendDecorator {
return func(s autorest.Sender) autorest.Sender {
return autorest.SenderFunc(func(r *http.Request) (*http.Response, error) {
r.Header.Set(consts.RequestHeaderTargetType, consts.TargetTypeAzureActiveDirectory)
return s.Do(r)
})
}
})()))
return spt
}
22 changes: 14 additions & 8 deletions pkg/auth/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,9 @@ func TestRedactClientCredentials(t *testing.T) {

func TestGetServicePrincipalTokenFromMSIWithUserAssignedID(t *testing.T) {
tests := []struct {
name string
config *config.AzureConfig
name string
config *config.AzureConfig
proxyMode bool // The proxy mode doesn't matter if user-assigned managed identity is used to get service principal token
}{
{
name: "using user-assigned managed identity to access keyvault",
Expand All @@ -73,6 +74,7 @@ func TestGetServicePrincipalTokenFromMSIWithUserAssignedID(t *testing.T) {
ClientID: "AADClientID",
ClientSecret: "AADClientSecret",
},
proxyMode: false,
},
// The Azure service principal is ignored when
// UseManagedIdentityExtension is set to true
Expand All @@ -82,12 +84,13 @@ func TestGetServicePrincipalTokenFromMSIWithUserAssignedID(t *testing.T) {
UseManagedIdentityExtension: true,
UserAssignedIdentityID: "clientID",
},
proxyMode: true,
},
}

for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
token, err := GetServicePrincipalToken(test.config, "https://login.microsoftonline.com/", "https://vault.azure.net")
token, err := GetServicePrincipalToken(test.config, "https://login.microsoftonline.com/", "https://vault.azure.net", test.proxyMode)
if err != nil {
t.Fatalf("expected err to be nil, got: %v", err)
}
Expand All @@ -108,14 +111,16 @@ func TestGetServicePrincipalTokenFromMSIWithUserAssignedID(t *testing.T) {

func TestGetServicePrincipalTokenFromMSI(t *testing.T) {
tests := []struct {
name string
config *config.AzureConfig
name string
config *config.AzureConfig
proxyMode bool // The proxy mode doesn't matter if MSI is used to get service principal token
}{
{
name: "using system-assigned managed identity to access keyvault",
config: &config.AzureConfig{
UseManagedIdentityExtension: true,
},
proxyMode: false,
},
// The Azure service principal is ignored when
// UseManagedIdentityExtension is set to true
Expand All @@ -127,12 +132,13 @@ func TestGetServicePrincipalTokenFromMSI(t *testing.T) {
ClientID: "AADClientID",
ClientSecret: "AADClientSecret",
},
proxyMode: true,
},
}

for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
token, err := GetServicePrincipalToken(test.config, "https://login.microsoftonline.com/", "https://vault.azure.net")
token, err := GetServicePrincipalToken(test.config, "https://login.microsoftonline.com/", "https://vault.azure.net", test.proxyMode)
if err != nil {
t.Fatalf("expected err to be nil, got: %v", err)
}
Expand Down Expand Up @@ -168,7 +174,7 @@ func TestGetServicePrincipalToken(t *testing.T) {

for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
token, err := GetServicePrincipalToken(test.config, "https://login.microsoftonline.com/", "https://vault.azure.net")
token, err := GetServicePrincipalToken(test.config, "https://login.microsoftonline.com/", "https://vault.azure.net", false)
if err != nil {
t.Fatalf("expected err to be nil, got: %v", err)
}
Expand All @@ -183,7 +189,7 @@ func TestGetServicePrincipalToken(t *testing.T) {
t.Fatalf("expected err to be nil, got: %v", err)
}
if !reflect.DeepEqual(token, spt) {
t.Fatalf("expected: %v, got: %v", spt, token)
t.Fatalf("expected: %+v, got: %+v", spt, token)
}
})
}
Expand Down
16 changes: 16 additions & 0 deletions pkg/consts/consts.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
// Copyright (c) Microsoft and contributors. All rights reserved.
//
// This source code is licensed under the MIT license found in the
// LICENSE file in the root directory of this source tree.

package consts

const (
// In proxy mode, the header is added into the requests from kms-plugin.
// The proxy will check the header and forward the request to different destinations.
// e.g. When the value of the header "x-azure-proxy-target" is "KeyVault", the request
// is forwared to Azure Key Vault by the proxy.
RequestHeaderTargetType = "x-azure-proxy-target"
TargetTypeAzureActiveDirectory = "AzureActiveDirectory"
TargetTypeKeyVault = "KeyVault"
)
24 changes: 21 additions & 3 deletions pkg/plugin/keyvault.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,16 @@ import (
"encoding/base64"
"fmt"
"regexp"
"strings"

"github.com/Azure/kubernetes-kms/pkg/auth"
"github.com/Azure/kubernetes-kms/pkg/config"
"github.com/Azure/kubernetes-kms/pkg/consts"
"github.com/Azure/kubernetes-kms/pkg/utils"
"github.com/Azure/kubernetes-kms/pkg/version"

kv "github.com/Azure/azure-sdk-for-go/services/keyvault/2016-10-01/keyvault"
"github.com/Azure/go-autorest/autorest"
"github.com/Azure/go-autorest/autorest/azure"
"k8s.io/klog/v2"
)
Expand All @@ -38,7 +41,7 @@ type keyVaultClient struct {
}

// NewKeyVaultClient returns a new key vault client to use for kms operations
func newKeyVaultClient(config *config.AzureConfig, vaultName, keyName, keyVersion string) (*keyVaultClient, error) {
func newKeyVaultClient(config *config.AzureConfig, vaultName, keyName, keyVersion string, proxyMode bool, proxyAddress string, proxyPort int) (*keyVaultClient, error) {
// Sanitize vaultName, keyName, keyVersion. (https://github.com/Azure/kubernetes-kms/issues/85)
vaultName = utils.SanitizeString(vaultName)
keyName = utils.SanitizeString(keyName)
Expand All @@ -58,7 +61,11 @@ func newKeyVaultClient(config *config.AzureConfig, vaultName, keyName, keyVersio
if err != nil {
return nil, fmt.Errorf("failed to parse cloud environment: %s, error: %+v", config.Cloud, err)
}
token, err := auth.GetKeyvaultToken(config, env)
if proxyMode {
env.ActiveDirectoryEndpoint = fmt.Sprintf("http://%s:%d/", proxyAddress, proxyPort)
}

token, err := auth.GetKeyvaultToken(config, env, proxyMode)
if err != nil {
return nil, fmt.Errorf("failed to get key vault token, error: %+v", err)
}
Expand All @@ -69,7 +76,12 @@ func newKeyVaultClient(config *config.AzureConfig, vaultName, keyName, keyVersio
return nil, fmt.Errorf("failed to get vault url, error: %+v", err)
}

klog.InfoS("using kms key for encrypt/decrypt", "vaultName", vaultName, "keyName", keyName, "keyVersion", keyVersion)
if proxyMode {
kvClient.RequestInspector = autorest.WithHeader(consts.RequestHeaderTargetType, consts.TargetTypeKeyVault)
vaultURL = getProxiedVaultURL(vaultURL, proxyAddress, proxyPort)
}

klog.InfoS("using kms key for encrypt/decrypt", "vaultURL", *vaultURL, "keyName", keyName, "keyVersion", keyVersion)

client := &keyVaultClient{
baseClient: kvClient,
Expand Down Expand Up @@ -130,5 +142,11 @@ func getVaultURL(vaultName string, azureEnvironment *azure.Environment) (vaultUR

vaultDNSSuffixValue := azureEnvironment.KeyVaultDNSSuffix
vaultURI := "https://" + vaultName + "." + vaultDNSSuffixValue + "/"

return &vaultURI, nil
}

func getProxiedVaultURL(vaultURL *string, proxyAddress string, proxyPort int) *string {
proxiedVaultURL := fmt.Sprintf("http://%s:%d/%s", proxyAddress, proxyPort, strings.TrimPrefix(*vaultURL, "https://"))
return &proxiedVaultURL
}
Loading

0 comments on commit 14bc5c4

Please sign in to comment.