Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Don't retry SSH connection when auth fails #818

Merged
merged 2 commits into from
Jan 11, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion internal/am/upload_transfer.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"path/filepath"

"github.com/go-logr/logr"
"go.artefactual.dev/tools/temporal"

"github.com/artefactual-sdps/enduro/internal/sftp"
)
Expand Down Expand Up @@ -46,7 +47,14 @@ func (a *UploadTransferActivity) Execute(ctx context.Context, params *UploadTran
filename := filepath.Base(params.SourcePath)
bytes, path, err := a.client.Upload(ctx, src, filename)
if err != nil {
return nil, fmt.Errorf("%s: %v", UploadTransferActivityName, err)
e := fmt.Errorf("%s: %v", UploadTransferActivityName, err)

switch err.(type) {
case *sftp.AuthError:
return nil, temporal.NewNonRetryableError(e)
default:
return nil, e
}
}

return &UploadTransferActivityResult{
Expand Down
48 changes: 37 additions & 11 deletions internal/am/upload_transfer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,15 @@ import (

"github.com/go-logr/logr"
"go.artefactual.dev/tools/mockutil"
"go.artefactual.dev/tools/temporal"
temporalsdk_activity "go.temporal.io/sdk/activity"
temporalsdk_testsuite "go.temporal.io/sdk/testsuite"
"go.uber.org/mock/gomock"
"gotest.tools/v3/assert"
tfs "gotest.tools/v3/fs"

"github.com/artefactual-sdps/enduro/internal/am"
"github.com/artefactual-sdps/enduro/internal/sftp"
sftp_fake "github.com/artefactual-sdps/enduro/internal/sftp/fake"
)

Expand All @@ -25,11 +27,12 @@ func TestUploadTransferActivity(t *testing.T) {
)

type test struct {
name string
params am.UploadTransferActivityParams
want am.UploadTransferActivityResult
recorder func(*sftp_fake.MockClientMockRecorder)
errMsg string
name string
params am.UploadTransferActivityParams
recorder func(*sftp_fake.MockClientMockRecorder)
want am.UploadTransferActivityResult
wantErr string
wantNonRetryErr bool
}
for _, tt := range []test{
{
Expand All @@ -56,10 +59,10 @@ func TestUploadTransferActivity(t *testing.T) {
params: am.UploadTransferActivityParams{
SourcePath: td.Join("missing"),
},
errMsg: fmt.Sprintf("UploadTransferActivity: open %s: no such file or directory", td.Join("missing")),
wantErr: fmt.Sprintf("activity error (type: UploadTransferActivity, scheduledEventID: 0, startedEventID: 0, identity: ): UploadTransferActivity: open %s: no such file or directory", td.Join("missing")),
},
{
name: "Errors when upload fails",
name: "Retryable error when SSH connection fails",
params: am.UploadTransferActivityParams{
SourcePath: td.Join(filename),
},
Expand All @@ -72,10 +75,32 @@ func TestUploadTransferActivity(t *testing.T) {
).Return(
0,
"",
errors.New("SSH: failed to connect: dial tcp 127.0.0.1:2200: connect: connection refused"),
errors.New("ssh: failed to connect: dial tcp 127.0.0.1:2200: connect: connection refused"),
)
},
errMsg: "UploadTransferActivity: SSH: failed to connect: dial tcp 127.0.0.1:2200: connect: connection refused",
wantErr: "activity error (type: UploadTransferActivity, scheduledEventID: 0, startedEventID: 0, identity: ): UploadTransferActivity: ssh: failed to connect: dial tcp 127.0.0.1:2200: connect: connection refused",
},
{
name: "Non-retryable error when authentication fails",
params: am.UploadTransferActivityParams{
SourcePath: td.Join(filename),
},
recorder: func(m *sftp_fake.MockClientMockRecorder) {
var t *os.File
m.Upload(
mockutil.Context(),
gomock.AssignableToTypeOf(t),
filename,
).Return(
0,
"",
sftp.NewAuthError(
errors.New("ssh: handshake failed: ssh: unable to authenticate, attempted methods [none publickey], no supported methods remain"),
),
)
},
wantErr: "activity error (type: UploadTransferActivity, scheduledEventID: 0, startedEventID: 0, identity: ): UploadTransferActivity: ssh: handshake failed: ssh: unable to authenticate, attempted methods [none publickey], no supported methods remain",
wantNonRetryErr: true,
},
} {
tt := tt
Expand All @@ -98,8 +123,9 @@ func TestUploadTransferActivity(t *testing.T) {
)

fut, err := env.ExecuteActivity(am.UploadTransferActivityName, tt.params)
if tt.errMsg != "" {
assert.ErrorContains(t, err, tt.errMsg)
if tt.wantErr != "" {
assert.Error(t, err, tt.wantErr)
assert.Assert(t, temporal.NonRetryableError(err) == tt.wantNonRetryErr)
return
}

Expand Down
16 changes: 16 additions & 0 deletions internal/sftp/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,22 @@ import (
"io"
)

// AuthError wraps an error that was caused by an authentication failure.
//
// It is a distinct type so that callers can recognise authentication
// failures by type (for example with errors.As or a type switch) and treat
// them differently from other errors.
type AuthError struct {
	// err is the underlying error whose message is reported by Error.
	err error
}

// Error returns the message of the wrapped error.
//
// If the receiver is nil, or it was built from a nil error, a generic
// message is returned instead of dereferencing a nil value and panicking.
func (e *AuthError) Error() string {
	if e == nil || e.err == nil {
		return "authentication failed"
	}
	return e.err.Error()
}

// Unwrap returns the wrapped error so that errors.Is and errors.As can
// inspect it. It returns nil when the receiver is nil.
func (e *AuthError) Unwrap() error {
	if e == nil {
		return nil
	}
	return e.err
}

// NewAuthError returns an *AuthError wrapping e.
func NewAuthError(e error) *AuthError {
	return &AuthError{err: e}
}
Copy link
Contributor

@sevein sevein Jan 8, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A possible alternative:

type AuthError struct {
	Reason string
}

func (e *AuthError) Error() string {
	return fmt.Sprintf("auth: %s", e.Reason)
}

func NewAuthError(err error) error {
	return &AuthError{Reason: err.Error()}
}

This version:

  • Doesn't use Unwrap because we don't need to give access to the underlying error.
  • Uses the error type in the constructor signature, see https://go.dev/doc/faq#nil_error:

    It's a good idea for functions that return errors always to use the error type in their signature rather than a concrete type such as *MyError, to help guarantee the error is created correctly. As an example, os.Open returns an error even though, if not nil, it's always of concrete type *os.PathError.

I think that we'd want the function to return a concrete type (*AuthError) if we were doing some performance-focused work, e.g. https://github.com/connectrpc/connect-go/blob/d88758dc0e89170db922ecd20f16cec57662ec23/error.go#L117-L128 could be a good example where the design was driven by performance objectives, like avoiding memory allocations. This is a guess though, not entirely sure.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@sevein does Reason need to be exported?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Exporting the Reason field does have the nice benefit of being able to set an arbitrary string without having to create an error first, e.g.

e := &sftp.AuthError{Reason: "bad password"}

vs.

e := sftp.NewAuthError(errors.New("bad password"))


// A Client manages the transmission of data over SFTP.
//
// Implementations of the Client interface handle the connection details,
Expand Down
2 changes: 1 addition & 1 deletion internal/sftp/goclient.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ func (c *GoClient) dial(ctx context.Context) (*connection, error) {

conn.ssh, err = sshConnect(ctx, c.logger, c.cfg)
if err != nil {
return nil, fmt.Errorf("SSH: %v", err)
return nil, err
}

conn.sftp, err = sftp.NewClient(conn.ssh)
Expand Down
29 changes: 20 additions & 9 deletions internal/sftp/goclient_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@ import (
"bufio"
"bytes"
"context"
"errors"
"fmt"
"io"
"log"
"net"
"os"
"reflect"
"strings"
"testing"
"time"
Expand Down Expand Up @@ -174,7 +176,7 @@ func TestUpload(t *testing.T) {
cfg sftp.Config
params params
want results
wantErr string
wantErr error
}
for _, tc := range []test{
{
Expand Down Expand Up @@ -231,7 +233,9 @@ func TestUpload(t *testing.T) {
src: strings.NewReader("Testing 1-2-3"),
dest: "test.txt",
},
wantErr: "SSH: parse private key with passphrase: x509: decryption password incorrect",
wantErr: sftp.NewAuthError(
errors.New("ssh: parse private key with passphrase: x509: decryption password incorrect"),
),
},
{
name: "Errors when the SFTP server isn't there",
Expand All @@ -247,8 +251,8 @@ func TestUpload(t *testing.T) {
src: strings.NewReader("Testing 1-2-3"),
dest: "test.txt",
},
wantErr: fmt.Sprintf(
"SSH: connect: dial tcp %s:%s: connect: connection refused",
wantErr: fmt.Errorf(
"ssh: connect: dial tcp %s:%s: connect: connection refused",
badHost, badPort,
),
},
Expand All @@ -262,7 +266,9 @@ func TestUpload(t *testing.T) {
Path: "./testdata/clientkeys/test_unk_ed25519",
},
},
wantErr: "SSH: connect: ssh: handshake failed: ssh: unable to authenticate, attempted methods [none publickey], no supported methods remain",
wantErr: sftp.NewAuthError(
errors.New("ssh: handshake failed: ssh: unable to authenticate, attempted methods [none publickey], no supported methods remain"),
),
},
{
name: "Errors when the host key is not in known_hosts",
Expand All @@ -274,7 +280,9 @@ func TestUpload(t *testing.T) {
Path: "./testdata/clientkeys/test_ed25519",
},
},
wantErr: "SSH: connect: ssh: handshake failed: knownhosts: key is unknown",
wantErr: sftp.NewAuthError(
errors.New("ssh: handshake failed: knownhosts: key is unknown"),
),
},
{
name: "Errors when the known_hosts file doesn't exist",
Expand All @@ -286,7 +294,9 @@ func TestUpload(t *testing.T) {
Path: "./testdata/clientkeys/test_ed25519",
},
},
wantErr: "SSH: parse known_hosts: open testdata/missing: no such file or directory",
wantErr: sftp.NewAuthError(
errors.New("ssh: parse known_hosts: open testdata/missing: no such file or directory"),
),
},
} {
tc := tc
Expand All @@ -300,8 +310,9 @@ func TestUpload(t *testing.T) {
sftpc := sftp.NewGoClient(logr.Discard(), tc.cfg)
bytes, remotePath, err := sftpc.Upload(context.Background(), tc.params.src, tc.params.dest)

if tc.wantErr != "" {
assert.Error(t, err, tc.wantErr)
if tc.wantErr != nil {
assert.Error(t, err, tc.wantErr.Error())
assert.Assert(t, reflect.TypeOf(err) == reflect.TypeOf(tc.wantErr))
return
}

Expand Down
23 changes: 14 additions & 9 deletions internal/sftp/ssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"net"
"os"
"path/filepath"
"strings"
"time"

"github.com/go-logr/logr"
Expand All @@ -17,32 +18,32 @@ import (
// returns a client connection.
//
// Only private key authentication is currently supported, with or without a
// passphrase.
// passphrase.SSH: %v",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like this was accidentally changed!

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oops! 🤦

func sshConnect(ctx context.Context, logger logr.Logger, cfg Config) (*ssh.Client, error) {
// Load private key for authentication.
keyBytes, err := os.ReadFile(filepath.Clean(cfg.PrivateKey.Path)) // #nosec G304 -- File data is validated below
if err != nil {
return nil, fmt.Errorf("read private key: %v", err)
return nil, NewAuthError(fmt.Errorf("ssh: read private key: %v", err))
}

// Create a signer from the private key, with or without a passphrase.
var signer ssh.Signer
if cfg.PrivateKey.Passphrase != "" {
signer, err = ssh.ParsePrivateKeyWithPassphrase(keyBytes, []byte(cfg.PrivateKey.Passphrase))
if err != nil {
return nil, fmt.Errorf("parse private key with passphrase: %v", err)
return nil, NewAuthError(fmt.Errorf("ssh: parse private key with passphrase: %v", err))
}
} else {
signer, err = ssh.ParsePrivateKey(keyBytes)
if err != nil {
return nil, fmt.Errorf("parse private key: %v", err)
return nil, NewAuthError(fmt.Errorf("ssh: parse private key: %v", err))
}
}

// Check that the host key is in the client's known_hosts file.
hostcallback, err := knownhosts.New(filepath.Clean(cfg.KnownHostsFile))
if err != nil {
return nil, fmt.Errorf("parse known_hosts: %v", err)
return nil, NewAuthError(fmt.Errorf("ssh: parse known_hosts: %v", err))
}

// Configure the SSH client.
Expand All @@ -61,14 +62,18 @@ func sshConnect(ctx context.Context, logger logr.Logger, cfg Config) (*ssh.Clien
dialer := &net.Dialer{}
conn, err := dialer.DialContext(ctx, "tcp", address)
if err != nil {
logger.V(2).Info("SSH dial failed", "address", address, "user", cfg.User)
return nil, fmt.Errorf("connect: %v", err)
logger.V(2).Info("ssh: dial failed", "address", address, "user", cfg.User)
return nil, fmt.Errorf("ssh: connect: %v", err)
}

sshConn, chans, reqs, err := ssh.NewClientConn(conn, address, sshConfig)
if err != nil {
logger.V(2).Info("SSH dial failed", "address", address, "user", cfg.User)
return nil, fmt.Errorf("connect: %v", err)
if strings.Contains(err.Error(), "ssh: unable to authenticate") || strings.Contains(err.Error(), "knownhosts: key is unknown") {
logger.V(2).Info("ssh: authentication failed", "address", address, "user", cfg.User)
return nil, NewAuthError(err)
}
logger.V(2).Info("ssh: failed to connect", "address", address, "user", cfg.User)
return nil, err
}

return ssh.NewClient(sshConn, chans, reqs), nil
Expand Down