Skip to content

Commit

Permalink
Merge pull request #91 from justinmerrell/port-ssh
Browse files Browse the repository at this point in the history
Port SSH Commands from Python and Minor Refactors
  • Loading branch information
DireLines authored Feb 9, 2024
2 parents 919ebaf + 4767fe2 commit 96e8b94
Show file tree
Hide file tree
Showing 46 changed files with 732 additions and 366 deletions.
12 changes: 11 additions & 1 deletion api/query.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ package api
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"net/http"
"os"
"runtime"
Expand All @@ -19,17 +21,25 @@ type Input struct {
func Query(input Input) (res *http.Response, err error) {
jsonValue, err := json.Marshal(input)
if err != nil {
return
return nil, err
}

apiUrl := os.Getenv("RUNPOD_API_URL")
if apiUrl == "" {
apiUrl = viper.GetString("apiUrl")
}

apiKey := os.Getenv("RUNPOD_API_KEY")
if apiKey == "" {
apiKey = viper.GetString("apiKey")
}

// Check if the API key is present
if apiKey == "" {
fmt.Println("API key not found")
return nil, errors.New("API key not found")
}

req, err := http.NewRequest("POST", apiUrl+"?api_key="+apiKey, bytes.NewBuffer(jsonValue))
if err != nil {
return
Expand Down
96 changes: 68 additions & 28 deletions api/user.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,21 @@ package api

import (
"encoding/json"
"errors"
"fmt"
"io"
"strings"

"golang.org/x/crypto/ssh"
)

func GetPublicSSHKeys() (keys string, err error) {
type SSHKey struct {
Name string `json:"name"`
Type string `json:"type"`
Key string `json:"key"`
Fingerprint string `json:"fingerprint"`
}

func GetPublicSSHKeys() (string, []SSHKey, error) {
input := Input{
Query: `
query myself {
Expand All @@ -19,48 +27,79 @@ func GetPublicSSHKeys() (keys string, err error) {
}
`,
}

res, err := Query(input)
if err != nil {
return "", err
return "", nil, err
}
defer res.Body.Close()

if res.StatusCode != 200 {
err = fmt.Errorf("statuscode %d", res.StatusCode)
return
return "", nil, fmt.Errorf("unexpected status code: %d", res.StatusCode)
}
defer res.Body.Close()

rawData, err := io.ReadAll(res.Body)
if err != nil {
return "", err
return "", nil, fmt.Errorf("failed to read response body: %w", err)
}
data := &UserOut{}
if err = json.Unmarshal(rawData, data); err != nil {
return "", err

var data UserOut
if err := json.Unmarshal(rawData, &data); err != nil {
return "", nil, fmt.Errorf("JSON unmarshal error: %w", err)
}

if len(data.Errors) > 0 {
err = errors.New(data.Errors[0].Message)
return "", err
return "", nil, fmt.Errorf("API error: %s", data.Errors[0].Message)
}
if data == nil || data.Data == nil || data.Data.Myself == nil {
err = fmt.Errorf("data is nil: %s", string(rawData))
return "", err

if data.Data == nil || data.Data.Myself == nil {
return "", nil, fmt.Errorf("nil data received: %s", string(rawData))
}

// Parse the public key string into a list of SSHKey structs
var keys []SSHKey
keyStrings := strings.Split(data.Data.Myself.PubKey, "\n")
for _, keyString := range keyStrings {
if keyString == "" {
continue
}

pubKey, name, _, _, err := ssh.ParseAuthorizedKey([]byte(keyString))
if err != nil {
continue // Skip keys that can't be parsed
}

keys = append(keys, SSHKey{
Name: name,
Type: pubKey.Type(),
Key: string(ssh.MarshalAuthorizedKey(pubKey)),
Fingerprint: ssh.FingerprintSHA256(pubKey),
})
}
return data.Data.Myself.PubKey, nil

return data.Data.Myself.PubKey, keys, nil
}

func AddPublicSSHKey(key []byte) error {
//pull existing pubKey
existingKeys, err := GetPublicSSHKeys()
rawKeys, existingKeys, err := GetPublicSSHKeys()
if err != nil {
return err
return fmt.Errorf("failed to get existing SSH keys: %w", err)
}

keyStr := string(key)
//check for key present
if strings.Contains(existingKeys, keyStr) {
return nil
for _, k := range existingKeys {
if strings.TrimSpace(k.Key) == strings.TrimSpace(keyStr) {
return nil
}
}
// concat key onto pubKey
newKeys := existingKeys + "\n\n" + keyStr
// set new pubKey

// Concatenate the new key onto the existing keys, separated by a newline
newKeys := strings.TrimSpace(rawKeys)
if newKeys != "" {
newKeys += "\n\n"
}
newKeys += strings.TrimSpace(keyStr)

input := Input{
Query: `
mutation Mutation($input: UpdateUserSettingsInput) {
Expand All @@ -71,9 +110,10 @@ func AddPublicSSHKey(key []byte) error {
`,
Variables: map[string]interface{}{"input": map[string]interface{}{"pubKey": newKeys}},
}
_, err = Query(input)
if err != nil {
return err

if _, err = Query(input); err != nil {
return fmt.Errorf("failed to update SSH keys: %w", err)
}

return nil
}
92 changes: 22 additions & 70 deletions cmd/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,100 +2,52 @@ package config

import (
"cli/api"
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"encoding/pem"
"errors"
"cli/cmd/ssh"
"fmt"
"os"
"path/filepath"

"github.com/spf13/cobra"
"github.com/spf13/viper"
"golang.org/x/crypto/ssh"
)

var ConfigFile string
var apiKey string
var apiUrl string
var (
ConfigFile string
apiKey string
apiUrl string
)

var ConfigCmd = &cobra.Command{
Use: "config",
Short: "CLI Config",
Long: "RunPod CLI Config Settings",
Run: func(c *cobra.Command, args []string) {
err := viper.WriteConfig()
cobra.CheckErr(err)
fmt.Println("saved apiKey into config file: " + ConfigFile)
home, err := os.UserHomeDir()
if err := viper.WriteConfig(); err != nil {
fmt.Fprintf(os.Stderr, "Error saving config: %v\n", err)
return
}
fmt.Println("Configuration saved to file:", viper.ConfigFileUsed())

publicKey, err := ssh.GenerateSSHKeyPair("RunPod-Key-Go")
if err != nil {
fmt.Println("couldn't get user home dir path")
fmt.Fprintf(os.Stderr, "Failed to generate SSH key: %v\n", err)
return
}
sshFolderPath := filepath.Join(home, ".runpod", "ssh")
os.MkdirAll(sshFolderPath, os.ModePerm)
privateSshPath := filepath.Join(sshFolderPath, "RunPod-Key-Go")
publicSshPath := filepath.Join(sshFolderPath, "RunPod-Key-Go.pub")
publicKey, _ := os.ReadFile(publicSshPath)
if _, err := os.Stat(privateSshPath); errors.Is(err, os.ErrNotExist) {
publicKey = makeRSAKey(privateSshPath)

if err := api.AddPublicSSHKey(publicKey); err != nil {
fmt.Fprintf(os.Stderr, "Failed to add the SSH key: %v\n", err)
return
}
api.AddPublicSSHKey(publicKey)
fmt.Println("SSH key added successfully.")
},
}

func init() {
ConfigCmd.Flags().StringVar(&apiKey, "apiKey", "", "runpod api key")
ConfigCmd.MarkFlagRequired("apiKey")
ConfigCmd.Flags().StringVar(&apiKey, "apiKey", "", "RunPod API key")
viper.BindPFlag("apiKey", ConfigCmd.Flags().Lookup("apiKey")) //nolint
viper.SetDefault("apiKey", "")

ConfigCmd.Flags().StringVar(&apiUrl, "apiUrl", "", "runpod api url")
ConfigCmd.Flags().StringVar(&apiUrl, "apiUrl", "https://api.runpod.io/graphql", "RunPod API URL")
viper.BindPFlag("apiUrl", ConfigCmd.Flags().Lookup("apiUrl")) //nolint
viper.SetDefault("apiUrl", "https://api.runpod.io/graphql")
}

func makeRSAKey(filename string) []byte {
bitSize := 2048

// Generate RSA key.
key, err := rsa.GenerateKey(rand.Reader, bitSize)
if err != nil {
panic(err)
}

// Extract public component.
pub := key.PublicKey

// Encode private key to PKCS#1 ASN.1 PEM.
keyPEM := pem.EncodeToMemory(
&pem.Block{
Type: "RSA PRIVATE KEY",
Bytes: x509.MarshalPKCS1PrivateKey(key),
},
)

// generate and write public key
publicKey, err := ssh.NewPublicKey(&pub)
if err != nil {
fmt.Println("err in NewPublicKey")
fmt.Println(err)
}
pubBytes := ssh.MarshalAuthorizedKey(publicKey)
pubBytes = append(pubBytes, []byte(" "+filename)...)

// Write private key to file.
if err := os.WriteFile(filename, keyPEM, 0600); err != nil {
fmt.Println("err writing priv")
panic(err)
}

// Write public key to file.
if err := os.WriteFile(filename+".pub", pubBytes, 0600); err != nil {
fmt.Println("err writing pub")
panic(err)
}
fmt.Println("saved new SSH public key into", filename+".pub")
return pubBytes
ConfigCmd.MarkFlagRequired("apiKey")
}
18 changes: 18 additions & 0 deletions cmd/exec.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
package cmd

import (
"cli/cmd/exec"

"github.com/spf13/cobra"
)

// execCmd represents the base command for executing commands in a pod
var execCmd = &cobra.Command{
Use: "exec",
Short: "Execute commands in a pod",
Long: `Execute a local file remotely in a pod.`,
}

func init() {
execCmd.AddCommand(exec.RemotePythonCmd)
}
39 changes: 39 additions & 0 deletions cmd/exec/commands.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
package exec

import (
"fmt"
"os"

"github.com/spf13/cobra"
)

var RemotePythonCmd = &cobra.Command{
Use: "python [file]",
Short: "Runs a remote Python shell",
Long: `Runs a remote Python shell with a local script file.`,
Args: cobra.ExactArgs(1),
Run: func(cmd *cobra.Command, args []string) {
podID, _ := cmd.Flags().GetString("pod_id")
file := args[0]

// Default to the session pod if no pod_id is provided
// if podID == "" {
// var err error
// podID, err = api.GetSessionPod()
// if err != nil {
// fmt.Fprintf(os.Stderr, "Error retrieving session pod: %v\n", err)
// return
// }
// }

fmt.Println("Running remote Python shell...")
if err := PythonOverSSH(podID, file); err != nil {
fmt.Fprintf(os.Stderr, "Error executing Python over SSH: %v\n", err)
}
},
}

func init() {
RemotePythonCmd.Flags().String("pod_id", "", "The ID of the pod to run the command on.")
RemotePythonCmd.MarkFlagRequired("file")
}
25 changes: 25 additions & 0 deletions cmd/exec/functions.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package exec

import (
"cli/cmd/project"
"fmt"
)

func PythonOverSSH(podID string, file string) error {
sshConn, err := project.PodSSHConnection(podID)
if err != nil {
return fmt.Errorf("getting SSH connection: %w", err)
}

// Copy the file to the pod using Rsync
if err := sshConn.Rsync(file, "/tmp/"+file, false); err != nil {
return fmt.Errorf("copying file to pod: %w", err)
}

// Run the file on the pod
if err := sshConn.RunCommand("python3.11 /tmp/" + file); err != nil {
return fmt.Errorf("running Python command: %w", err)
}

return nil
}
Loading

0 comments on commit 96e8b94

Please sign in to comment.