Skip to content

Commit

Permalink
Better auth method cycling
Browse files Browse the repository at this point in the history
  • Loading branch information
bomoko committed Nov 29, 2023
1 parent df7d500 commit 9bddffd
Show file tree
Hide file tree
Showing 3 changed files with 184 additions and 48 deletions.
15 changes: 0 additions & 15 deletions cmd/sync.go
Original file line number Diff line number Diff line change
Expand Up @@ -151,21 +151,6 @@ func syncCommandRun(cmd *cobra.Command, args []string) {
sshKey = sshConfig.PrivateKey
}

if sshKey == "" { //let's try guess it from the OS
userPath, err := os.UserHomeDir()
if err != nil {
utils.LogWarning("No ssh key given and no home directory available", os.Stdout)
}
potentialKey := fmt.Sprintf("%s/.ssh/id_rsa", userPath)
_, err = os.Stat(potentialKey)
if err != nil {
if SSHSkipAgent == true {
utils.LogFatalError(fmt.Sprintf("Unable to find key at fallback location '%v' - please provide an ssh key with the `--ssh-key` option, or use the ssh-agent", potentialKey), os.Stderr)
}
}
sshKey = potentialKey
}

sshVerbose := SSHVerbose
if sshConfig.Verbose && !sshVerbose {
sshVerbose = sshConfig.Verbose
Expand Down
176 changes: 143 additions & 33 deletions utils/shell.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,20 @@ package utils

import (
"bytes"
"errors"
"fmt"
"golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/agent"
"net"
"os"
"os/exec"
"path/filepath"
)

const ShellToUse = "sh"

var validAuthMethod *ssh.AuthMethod

func Shellout(command string) (error, string, string) {
var stdout bytes.Buffer
var stderr bytes.Buffer
Expand All @@ -21,44 +26,65 @@ func Shellout(command string) (error, string, string) {
return err, stdout.String(), stderr.String()
}

func RemoteShellout(command string, remoteUser string, remoteHost string, remotePort string, privateKeyfile string, skipSshAgent bool) (error, string) {

sshAuthSock, present := os.LookupEnv("SSH_AUTH_SOCK")
skipAgent := !present || skipSshAgent

var authMethods []ssh.AuthMethod

if skipAgent != true {
// Connect to SSH agent to ask for unencrypted private keys
if sshAgentConn, err := net.Dial("unix", sshAuthSock); err == nil {
sshAgent := agent.NewClient(sshAgentConn)
keys, _ := sshAgent.List()
if len(keys) > 0 {
agentAuthmethods := ssh.PublicKeysCallback(sshAgent.Signers)
authMethods = append(authMethods, agentAuthmethods)
}
}
} else {
LogDebugInfo("Skipping ssh agent", os.Stdout)
}

privateKeyBytes, err := os.ReadFile(privateKeyfile)
func getAuthMethodFromPrivateKey(filename string) (ssh.AuthMethod, error) {
privateKeyBytes, err := os.ReadFile(filename)

// if there are authMethods already, let's keep going
if err != nil && len(authMethods) == 0 {
return err, ""
if err != nil {
return nil, err
}

if len(privateKeyBytes) > 0 {
// Parse the private key
signer, err := ssh.ParsePrivateKey(privateKeyBytes)
if err != nil {
return err, ""
return nil, err
}

// SSH client configuration
authKeys := ssh.PublicKeys(signer)
authMethods = append(authMethods, authKeys)
return authKeys, nil

}
return nil, errors.New(fmt.Sprint("No data in privateKey: ", filename))
}

func findSSHKeyFiles(directory string) ([]string, error) {
var keys []string
err := filepath.Walk(directory, func(path string, info os.FileInfo, err error) error {
if err != nil {
return err
}

if !info.IsDir() && filepath.Ext(path) == ".pub" {
privateKeyPath := path[:len(path)-4] // remove ".pub" extension
keys = append(keys, privateKeyPath)
}
return nil
})
if err != nil {
return nil, err
}
return keys, nil
}

func isPassphraseMissingError(err error) bool {
_, ok := err.(*ssh.PassphraseMissingError)
return ok
}

func RemoteShellout(command string, remoteUser string, remoteHost string, remotePort string, privateKeyfile string, skipSshAgent bool) (error, string) {

sshAuthSock, present := os.LookupEnv("SSH_AUTH_SOCK")
skipAgent := !present || skipSshAgent

var authMethods []ssh.AuthMethod

if validAuthMethod == nil { // This makes it so that in subsequent calls, we don't have to recheck all auth methods
authMethods = getAuthmethods(skipAgent, privateKeyfile, sshAuthSock, authMethods)
}

if len(authMethods) == 0 && validAuthMethod == nil {
return errors.New("No valid authentication methods provided"), ""
}

config := &ssh.ClientConfig{
Expand All @@ -67,14 +93,39 @@ func RemoteShellout(command string, remoteUser string, remoteHost string, remote
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
}

// Connect to the remote server
client, err := ssh.Dial("tcp", remoteHost+":"+remotePort, config)
if err != nil {
return err, ""
var client *ssh.Client
var err error
if validAuthMethod != nil {
LogDebugInfo("Have a valid auth method", os.Stdout)
// Connect to the remote server
config.Auth = []ssh.AuthMethod{
*validAuthMethod,
}
client, err = ssh.Dial("tcp", remoteHost+":"+remotePort, config)
if err != nil {
return err, ""
}
defer client.Close()
} else {
//we need to iterate over the auth methods till we find one that works
LogDebugInfo("Trying an auth method", os.Stdout)
for _, am := range authMethods {
config.Auth = []ssh.AuthMethod{
am,
}
client, err = ssh.Dial("tcp", remoteHost+":"+remotePort, config)
if err != nil {
continue
}
validAuthMethod = &am // set the valid auth method so that future calls won't need to retry
break
}
if validAuthMethod == nil {
return errors.New("unable to find valid auth method for ssh"), ""
}
defer client.Close()
}

defer client.Close()

// Create a session
session, err := client.NewSession()
if err != nil {
Expand Down Expand Up @@ -105,3 +156,62 @@ func RemoteShellout(command string, remoteUser string, remoteHost string, remote

return nil, outputBuffer.String()
}

func getAuthmethods(skipAgent bool, privateKeyfile string, sshAuthSock string, authMethods []ssh.AuthMethod) []ssh.AuthMethod {
if skipAgent != true && privateKeyfile == "" {
// Connect to SSH agent to ask for unencrypted private keys
if sshAgentConn, err := net.Dial("unix", sshAuthSock); err == nil {
sshAgent := agent.NewClient(sshAgentConn)
keys, _ := sshAgent.List()
if len(keys) > 0 {
agentAuthmethods := ssh.PublicKeysCallback(sshAgent.Signers)
authMethods = append(authMethods, agentAuthmethods)
}
}
} else {
LogDebugInfo("Skipping ssh agent", os.Stdout)
}

if privateKeyfile == "" { //let's try guess it from the OS
userPath, err := os.UserHomeDir()
if err != nil {
LogWarning("No ssh key given and no home directory available", os.Stdout)
}

userPath = filepath.Join(userPath, ".ssh")

if _, err := os.Stat(userPath); err == nil {
files, err := findSSHKeyFiles(userPath)
if err != nil {
LogWarning(err.Error(), os.Stdout)
}
for _, f := range files {
am, err := getAuthMethodFromPrivateKey(f)
if err != nil {
switch {
case isPassphraseMissingError(err):
LogDebugInfo(fmt.Sprintf("Found a passphrase based ssh key: %v", err.Error()), os.Stdout)
default:
LogWarning(err.Error(), os.Stdout)
}
} else {
authMethods = append(authMethods, am)
}
}
} else {
LogWarning("Unable to find .ssh directory in user home", os.Stdout)
}
} else {
privateKeyFiles := []string{
privateKeyfile,
}

for _, kf := range privateKeyFiles {
am, err := getAuthMethodFromPrivateKey(kf)
if err == nil {
authMethods = append(authMethods, am)
}
}
}
return authMethods
}
41 changes: 41 additions & 0 deletions utils/shell_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
package utils

import (
"reflect"
"testing"
)

func Test_findSSHKeyFiles(t *testing.T) {
type args struct {
directory string
}
tests := []struct {
name string
args args
want []string
wantErr bool
}{
{
name: "Run on test directory",
args: args{directory: "../test-resources/shell-tests/test_findSSHKeyFiles"},
want: []string{
"../test-resources/shell-tests/test_findSSHKeyFiles/key1",
"../test-resources/shell-tests/test_findSSHKeyFiles/key2",
"../test-resources/shell-tests/test_findSSHKeyFiles/key3",
},
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := findSSHKeyFiles(tt.args.directory)
if (err != nil) != tt.wantErr {
t.Errorf("findSSHKeyFiles() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("findSSHKeyFiles() got = %v, want %v", got, tt.want)
}
})
}
}

0 comments on commit 9bddffd

Please sign in to comment.