Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Azure VM MSI support #584

Merged
merged 7 commits into from
Mar 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions cmd/kola/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ func init() {
sv(&kola.AzureOptions.ResourceGroup, "azure-resource-group", "", "Deploy resources in an existing resource group")
sv(&kola.AzureOptions.AvailabilitySet, "azure-availability-set", "", "Deploy instances with an existing availibity set")
sv(&kola.AzureOptions.KolaVnet, "azure-kola-vnet", "", "Pass the vnet/subnet that kola is being ran from to restrict network access to created storage accounts")
sv(&kola.AzureOptions.VMIdentity, "azure-vm-identity", "", "Assign a managed identity to the VM by name (will be looked up for its ID)")

// do-specific options
sv(&kola.DOOptions.ConfigPath, "do-config-file", "", "DigitalOcean config file (default \"~/"+auth.DOConfigPath+"\")")
Expand Down Expand Up @@ -416,5 +417,15 @@ func GetSSHKeys(sshKeys []string) ([]agent.Key, error) {
allKeys = append(allKeys, key)
}

// Ignition v3 does not allow duplicate keys so we need to deduplicate
allUniqueKeys := make(map[string]*agent.Key)
for _, key := range allKeys {
allUniqueKeys[string(key.Blob)] = &key
}
allKeys = []agent.Key{}
for _, value := range allUniqueKeys {
allKeys = append(allKeys, *value)
}

return allKeys, nil
}
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ require (
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.12.0
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.7.0
github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v5 v5.7.0
github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/msi/armmsi v1.2.0
github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v5 v5.2.0
github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armresources v1.2.0
github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armsubscriptions v1.3.0
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/internal/v3 v3.0.0 h1:Kb8e
github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/internal/v3 v3.0.0/go.mod h1:lYq15QkJyEsNegz5EhI/0SXQ6spvGfgwBH/Qyzkoc/s=
github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/managementgroups/armmanagementgroups v1.0.0 h1:pPvTJ1dY0sA35JOeFq6TsY2xj6Z85Yo23Pj4wCCvu4o=
github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/managementgroups/armmanagementgroups v1.0.0/go.mod h1:mLfWfj8v3jfWKsL9G4eoBoXVcsqcIUTapmdKy7uGOp0=
github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/msi/armmsi v1.2.0 h1:z4YeiSXxnUI+PqB46Yj6MZA3nwb1CcJIkEMDrzUd8Cs=
github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/msi/armmsi v1.2.0/go.mod h1:rko9SzMxcMk0NJsNAxALEGaTYyy79bNRwxgJfrH0Spw=
github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v5 v5.2.0 h1:qBlqTo40ARdI7Pmq+enBiTnejZk2BF+PHgktgG8k3r8=
github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v5 v5.2.0/go.mod h1:UmyOatRyQodVpp55Jr5WJmnkmVW4wKfo85uHFmMEjfM=
github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armresources v1.2.0 h1:Dd+RhdJn0OTtVGaeDLZpcumkIVCtA/3/Fo42+eoYvVM=
Expand Down
45 changes: 45 additions & 0 deletions platform/api/azure/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import (
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
"github.com/Azure/azure-sdk-for-go/sdk/azidentity"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v5"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/msi/armmsi"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v5"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armresources"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armsubscriptions"
Expand Down Expand Up @@ -59,6 +60,7 @@ type API struct {
ipClient *armnetwork.PublicIPAddressesClient
intClient *armnetwork.InterfacesClient
accClient *armstorage.AccountsClient
msiClient *armmsi.UserAssignedIdentitiesClient
Opts *Options
}

Expand Down Expand Up @@ -194,6 +196,12 @@ func (a *API) SetupClients() error {
}
a.accClient = scf.NewAccountsClient()

mcf, err := armmsi.NewClientFactory(a.subID, a.creds, opts)
if err != nil {
return err
}
a.msiClient = mcf.NewUserAssignedIdentitiesClient()

return nil
}

Expand Down Expand Up @@ -302,3 +310,40 @@ func (a *API) GC(gracePeriod time.Duration) error {

return nil
}

// FindManagedIdentityID searches for a managed identity by name across the subscription
// and returns its resource ID if found
func (a *API) FindManagedIdentityID(identityName string) (string, error) {
ctx := context.TODO()

// Use NewListBySubscriptionPager to search across the entire subscription
pager := a.msiClient.NewListBySubscriptionPager(nil)

for pager.More() {
page, err := pager.NextPage(ctx)
if err != nil {
return "", fmt.Errorf("failed to list managed identities: %v", err)
}

// Check each identity for a name match
for _, identity := range page.Value {
if identity.Name != nil && *identity.Name == identityName {
if identity.ID == nil || *identity.ID == "" {
continue
}

// Extract resource group name from the ID for logging
idParts := strings.Split(*identity.ID, "/")
var resourceGroup string
if len(idParts) >= 5 {
resourceGroup = idParts[4]
}

plog.Infof("Found managed identity %s in resource group %s", identityName, resourceGroup)
return *identity.ID, nil
}
}
}

return "", fmt.Errorf("managed identity %q was not found in the subscription", identityName)
}
83 changes: 42 additions & 41 deletions platform/api/azure/instance.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ import (
"encoding/base64"
"fmt"
"io"
"regexp"
"net/http"
"time"

"github.com/Azure/azure-sdk-for-go/sdk/azcore/to"
Expand Down Expand Up @@ -53,7 +53,7 @@ func (a *API) getVMRG(rg string) string {
return vmrg
}

func (a *API) getVMParameters(name, sshkey, storageAccountURI string, userdata *conf.Conf, ip *armnetwork.PublicIPAddress, nic *armnetwork.Interface) armcompute.VirtualMachine {
func (a *API) getVMParameters(name, sshkey string, userdata *conf.Conf, ip *armnetwork.PublicIPAddress, nic *armnetwork.Interface, managedIdentityID string) armcompute.VirtualMachine {
osProfile := armcompute.OSProfile{
AdminUsername: to.Ptr("core"),
ComputerName: &name,
Expand Down Expand Up @@ -113,6 +113,8 @@ func (a *API) getVMParameters(name, sshkey, storageAccountURI string, userdata *
plog.Warningf("failed to get image info: %v; continuing", err)
}
}

// Set up the VM configuration
vm := armcompute.VirtualMachine{
Name: &name,
Location: &a.Opts.Location,
Expand Down Expand Up @@ -148,22 +150,21 @@ func (a *API) getVMParameters(name, sshkey, storageAccountURI string, userdata *
},
DiagnosticsProfile: &armcompute.DiagnosticsProfile{
BootDiagnostics: &armcompute.BootDiagnostics{
Enabled: to.Ptr(true),
StorageURI: &storageAccountURI,
Enabled: to.Ptr(true),
},
},
},
}

// Configure disk controller if specified
switch a.Opts.DiskController {
case "nvme":
vm.Properties.StorageProfile.DiskControllerType = to.Ptr(armcompute.DiskControllerTypesNVMe)
case "scsi":
vm.Properties.StorageProfile.DiskControllerType = to.Ptr(armcompute.DiskControllerTypesSCSI)
}

// I don't think it would be an issue to have empty user-data set but better
// to be safe than sorry.
// Configure user data or custom data
if ud != "" {
if a.Opts.UseUserData && userdata.IsIgnition() {
plog.Infof("using user-data")
Expand All @@ -174,15 +175,29 @@ func (a *API) getVMParameters(name, sshkey, storageAccountURI string, userdata *
}
}

// Configure availability set if specified
availabilitySetID := a.getAvset()
if availabilitySetID != "" {
vm.Properties.AvailabilitySet = &armcompute.SubResource{ID: &availabilitySetID}
}

// Configure managed identity if specified
if managedIdentityID != "" {
plog.Infof("Assigning managed identity to VM (using pre-looked-up ID)")

// Configure the VM with the user assigned managed identity
vm.Identity = &armcompute.VirtualMachineIdentity{
Type: to.Ptr(armcompute.ResourceIdentityTypeUserAssigned),
UserAssignedIdentities: map[string]*armcompute.UserAssignedIdentitiesValue{
managedIdentityID: {},
},
}
}

return vm
}

func (a *API) CreateInstance(name, sshkey, resourceGroup, storageAccount string, userdata *conf.Conf, network Network) (*Machine, error) {
func (a *API) CreateInstance(name, sshkey, resourceGroup string, userdata *conf.Conf, network Network, managedIdentityID string) (*Machine, error) {
// only VMs are created in the user supplied resource group, kola still manages a resource group
// for the gallery and storage account.
vmResourceGroup := a.getVMRG(resourceGroup)
Expand All @@ -204,7 +219,8 @@ func (a *API) CreateInstance(name, sshkey, resourceGroup, storageAccount string,
return nil, fmt.Errorf("couldn't get NIC name")
}

vmParams := a.getVMParameters(name, sshkey, fmt.Sprintf("https://%s.blob.core.windows.net/", storageAccount), userdata, ip, nic)
// Pass the managedIdentityID to getVMParameters
vmParams := a.getVMParameters(name, sshkey, userdata, ip, nic, managedIdentityID)
plog.Infof("Creating Instance %s", name)

clean := func() {
Expand Down Expand Up @@ -284,53 +300,38 @@ func (a *API) TerminateInstance(machine *Machine, resourceGroup string) error {
return err
}

func (a *API) GetConsoleOutput(name, resourceGroup, storageAccount string) ([]byte, error) {
func (a *API) GetConsoleOutput(name, resourceGroup string) ([]byte, error) {
vmResourceGroup := a.getVMRG(resourceGroup)
vm, err := a.compClient.Get(context.TODO(), vmResourceGroup, name, &armcompute.VirtualMachinesClientGetOptions{
Expand: to.Ptr(armcompute.InstanceViewTypesInstanceView),
})
param := &armcompute.VirtualMachinesClientRetrieveBootDiagnosticsDataOptions{
SasURIExpirationTimeInMinutes: to.Ptr[int32](5),
}
resp, err := a.compClient.RetrieveBootDiagnosticsData(context.TODO(), vmResourceGroup, name, param)
if err != nil {
return nil, fmt.Errorf("could not get VM: %v", err)
}

consoleURI := vm.Properties.InstanceView.BootDiagnostics.SerialConsoleLogBlobURI
if consoleURI == nil {
if resp.SerialConsoleLogBlobURI == nil {
return nil, fmt.Errorf("serial console URI is nil")
}

// Only the full URI to the logs are present in the virtual machine
// properties. Parse out the container & file name to use the GetBlob
// API call directly.
uri := []byte(*consoleURI)
containerPat := regexp.MustCompile(`bootdiagnostics-[a-z0-9\-]+`)
container := string(containerPat.Find(uri))
if container == "" {
return nil, fmt.Errorf("could not find container name in URI: %q", *consoleURI)
}
namePat := regexp.MustCompile(`[a-z0-9\-\.]+.serialconsole.log`)
blobname := string(namePat.Find(uri))
if blobname == "" {
return nil, fmt.Errorf("could not find blob name in URI: %q", *consoleURI)
}

client, err := a.GetBlobServiceClient(storageAccount)
if err != nil {
return nil, err
}
var data io.ReadCloser
var output []byte
err = util.Retry(6, 10*time.Second, func() error {
data, err = GetBlob(client, container, blobname)
reply, err := http.Get(*resp.SerialConsoleLogBlobURI)
if err != nil {
return fmt.Errorf("could not get blob for container %q, blobname %q: %v", container, blobname, err)
return fmt.Errorf("could not GET console output: %v", err)
}
body := reply.Body
defer body.Close()
if reply.StatusCode != 200 {
return fmt.Errorf("unexpected status code: %v", reply.StatusCode)
}
if data == nil {
return fmt.Errorf("empty data while getting blob for container %q, blobname %q", container, blobname)
output, err = io.ReadAll(body)
if err != nil {
return fmt.Errorf("could not read console output: %v", err)
}
return nil
})
if err != nil {
return nil, err
}

return io.ReadAll(data)
return output, nil
}
2 changes: 2 additions & 0 deletions platform/api/azure/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,4 +63,6 @@ type Options struct {
ResourceGroup string
// AvailabilitySet is an existing availability set to deploy the instance in.
AvailabilitySet string
// VMIdentity is the name of a managed identity to assign to the VM.
VMIdentity string
}
2 changes: 1 addition & 1 deletion platform/cluster.go
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ func (bc *BaseCluster) RenderUserData(userdata *conf.UserData, ignitionVars map[
}
}

if bc.bf.AdditionalSshKeys != nil && *bc.bf.AdditionalSshKeys != nil {
if bc.bf.AdditionalSshKeys != nil && *bc.bf.AdditionalSshKeys != nil && !bc.rconf.NoSSHKeyInUserData {
userdata = conf.AddSSHKeys(userdata, bc.bf.AdditionalSshKeys)
}

Expand Down
13 changes: 7 additions & 6 deletions platform/machine/azure/cluster.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,11 @@ import (

type cluster struct {
*platform.BaseCluster
flight *flight
sshKey string
ResourceGroup string
StorageAccount string
Network azure.Network
flight *flight
sshKey string
ResourceGroup string
Network azure.Network
ManagedIdentityID string // Add managed identity ID field to cluster struct
}

func (ac *cluster) vmname() string {
Expand All @@ -48,7 +48,8 @@ func (ac *cluster) NewMachine(userdata *conf.UserData) (platform.Machine, error)
return nil, err
}

instance, err := ac.flight.Api.CreateInstance(ac.vmname(), ac.sshKey, ac.ResourceGroup, ac.StorageAccount, conf, ac.Network)
// Pass the managed identity ID to the CreateInstance method
instance, err := ac.flight.Api.CreateInstance(ac.vmname(), ac.sshKey, ac.ResourceGroup, conf, ac.Network, ac.ManagedIdentityID)
if err != nil {
return nil, err
}
Expand Down
22 changes: 14 additions & 8 deletions platform/machine/azure/flight.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ type flight struct {
ImageResourceGroup string
ImageStorageAccount string
Network azure.Network
ManagedIdentityID string // Store the managed identity resource ID
}

// NewFlight creates an instance of a Flight suitable for spawning
Expand Down Expand Up @@ -79,6 +80,16 @@ func NewFlight(opts *azure.Options) (platform.Flight, error) {
return nil, err
}

// If a managed identity is specified, look it up across all resource groups
// and fail fast if it can't be found
if opts.VMIdentity != "" {
plog.Infof("Looking up managed identity %q", opts.VMIdentity)
af.ManagedIdentityID, err = api.FindManagedIdentityID(opts.VMIdentity)
if err != nil {
return nil, err
}
}

if opts.BlobURL != "" || opts.ImageFile != "" {
imageName := fmt.Sprintf("%v", time.Now().UnixNano())
blobName := imageName + ".vhd"
Expand Down Expand Up @@ -152,8 +163,9 @@ func (af *flight) NewCluster(rconf *platform.RuntimeConfig) (platform.Cluster, e
}

ac := &cluster{
BaseCluster: bc,
flight: af,
BaseCluster: bc,
flight: af,
ManagedIdentityID: af.ManagedIdentityID,
}

if !rconf.NoSSHKeyInMetadata {
Expand All @@ -164,19 +176,13 @@ func (af *flight) NewCluster(rconf *platform.RuntimeConfig) (platform.Cluster, e

if af.ImageResourceGroup != "" && af.ImageStorageAccount != "" {
ac.ResourceGroup = af.ImageResourceGroup
ac.StorageAccount = af.ImageStorageAccount
ac.Network = af.Network
} else {
ac.ResourceGroup, err = af.Api.CreateResourceGroup("kola-cluster")
if err != nil {
return nil, err
}

ac.StorageAccount, err = af.Api.CreateStorageAccount(ac.ResourceGroup)
if err != nil {
return nil, err
}

ac.Network, err = af.Api.PrepareNetworkResources(ac.ResourceGroup)
if err != nil {
ac.Destroy()
Expand Down
2 changes: 1 addition & 1 deletion platform/machine/azure/machine.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ func (am *machine) ConsoleOutput() string {

func (am *machine) saveConsole() error {
var err error
am.console, err = am.cluster.flight.Api.GetConsoleOutput(am.ID(), am.ResourceGroup(), am.cluster.StorageAccount)
am.console, err = am.cluster.flight.Api.GetConsoleOutput(am.ID(), am.ResourceGroup())
if err != nil {
return err
}
Expand Down
Loading
Loading