Skip to content

Commit

Permalink
Merge pull request #159 from Yhlong00/main
Browse files Browse the repository at this point in the history
E-1654: fix: only add key when not exist in db.
  • Loading branch information
DireLines authored Sep 9, 2024
2 parents 7feb069 + a932587 commit ae0a532
Show file tree
Hide file tree
Showing 2 changed files with 97 additions and 13 deletions.
85 changes: 72 additions & 13 deletions cmd/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@ package config

import (
"fmt"
"os"

"github.com/runpod/runpodctl/api"
"github.com/runpod/runpodctl/cmd/ssh"

"github.com/spf13/cobra"
"github.com/spf13/viper"
sshcrypto "golang.org/x/crypto/ssh"
)

var (
Expand All @@ -21,27 +21,86 @@ var ConfigCmd = &cobra.Command{
Use: "config",
Short: "Manage CLI configuration",
Long: "RunPod CLI Config Settings",
Run: func(c *cobra.Command, args []string) {
if err := viper.WriteConfig(); err != nil {
fmt.Fprintf(os.Stderr, "Error saving config: %v\n", err)
return
RunE: func(c *cobra.Command, args []string) error {
if err := saveConfig(); err != nil {
return fmt.Errorf("error saving config: %w", err)
}
fmt.Println("Configuration saved to file:", viper.ConfigFileUsed())

publicKey, err := ssh.GenerateSSHKeyPair("RunPod-Key-Go")
publicKey, err := getOrCreateSSHKey()
if err != nil {
fmt.Fprintf(os.Stderr, "Failed to generate SSH key: %v\n", err)
return
return fmt.Errorf("failed to get or create local SSH key: %w", err)
}

if err := api.AddPublicSSHKey(publicKey); err != nil {
fmt.Fprintf(os.Stderr, "Failed to add the SSH key: %v\n", err)
return
if err := ensureSSHKeyInCloud(publicKey); err != nil {
return fmt.Errorf("failed to update SSH key in the cloud: %w", err)
}
fmt.Println("SSH key added successfully.")

return nil
},
}

// saveConfig saves the CLI configuration to a file
func saveConfig() error {
if err := viper.WriteConfig(); err != nil {
return err
}
fmt.Println("Configuration saved to file:", viper.ConfigFileUsed())
return nil
}

// Checks for an existing local SSH key and generates a new one if not found
func getOrCreateSSHKey() ([]byte, error) {
publicKey, err := ssh.GetLocalSSHKey()
if err != nil {
return nil, fmt.Errorf("error checking for local SSH key: %w", err)
}

if publicKey == nil {
fmt.Println("No existing local SSH key found, generating a new one.")
publicKey, err = ssh.GenerateSSHKeyPair("RunPod-Key-Go")
if err != nil {
return nil, fmt.Errorf("failed to generate SSH key: %w", err)
}
fmt.Println("New SSH key pair generated.")
} else {
fmt.Println("Existing local SSH key found.")
}

return publicKey, nil
}

// ensureSSHKeyInCloud checks if the SSH key exists in the cloud and adds it if necessary
func ensureSSHKeyInCloud(publicKey []byte) error {
_, cloudKeys, err := api.GetPublicSSHKeys()
if err != nil {
return fmt.Errorf("failed to get SSH keys from the cloud: %w", err)
}

// Parse the local public key
localPubKey, _, _, _, err := sshcrypto.ParseAuthorizedKey(publicKey)
if err != nil {
return fmt.Errorf("failed to parse local public key: %w", err)
}

localFingerprint := sshcrypto.FingerprintSHA256(localPubKey)

// Check if the publicKey already exists in the cloud
for _, cloudKey := range cloudKeys {
if cloudKey.Fingerprint == localFingerprint {
fmt.Println("SSH key already exists in the cloud. No action needed.")
return nil
}
}

// If the key doesn't exist, add it
if err := api.AddPublicSSHKey(publicKey); err != nil {
return fmt.Errorf("failed to add the SSH key: %w", err)
}

fmt.Println("SSH key added successfully to the cloud.")
return nil
}

func init() {
ConfigCmd.Flags().StringVar(&apiKey, "apiKey", "", "RunPod API key")
viper.BindPFlag("apiKey", ConfigCmd.Flags().Lookup("apiKey")) //nolint
Expand Down
25 changes: 25 additions & 0 deletions cmd/ssh/functions.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,3 +63,28 @@ func GenerateSSHKeyPair(keyName string) ([]byte, error) {
fmt.Printf("SSH key pair generated: %s (private), %s (public)\n", privateKeyPath, publicKeyPath)
return publicKeyBytes, nil
}

func GetLocalSSHKey() ([]byte, error) {
homeDir, err := os.UserHomeDir()
if err != nil {
return nil, fmt.Errorf("failed to get user home directory: %w", err)
}

keyPath := filepath.Join(homeDir, ".runpod", "ssh", "RunPod-Key-Go.pub")

publicKey, err := os.ReadFile(keyPath)
if err != nil {
if os.IsNotExist(err) {
return nil, nil // No existing key found
}
return nil, fmt.Errorf("failed to read existing public key: %w", err)
}

// Validate the key format
_, _, _, _, err = ssh.ParseAuthorizedKey(publicKey)
if err != nil {
return nil, fmt.Errorf("invalid public key format: %w", err)
}

return publicKey, nil
}

0 comments on commit ae0a532

Please sign in to comment.